aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
author Yanbo Liang <ybliang8@gmail.com> 2015-10-28 08:54:20 -0700
committer Xiangrui Meng <meng@databricks.com> 2015-10-28 08:54:20 -0700
commit f92b7b98e9998a6069996cc66ca26cbfa695fce5 (patch)
tree dacb106997450f0287aae73bda12d33096b693b3 /python
parent fba9e95452ca0a9b589bc14b27c750c69f482b8d (diff)
downloadspark-f92b7b98e9998a6069996cc66ca26cbfa695fce5.tar.gz
spark-f92b7b98e9998a6069996cc66ca26cbfa695fce5.tar.bz2
spark-f92b7b98e9998a6069996cc66ca26cbfa695fce5.zip
[SPARK-11367][ML][PYSPARK] Python LinearRegression should support setting solver
[SPARK-10668](https://issues.apache.org/jira/browse/SPARK-10668) provided the ```WeightedLeastSquares``` solver ("normal") for ```LinearRegression``` with L2 regularization in Scala and R. Python ML ```LinearRegression``` should also support setting the solver ("auto", "normal", "l-bfgs"). Author: Yanbo Liang <ybliang8@gmail.com> Closes #9328 from yanboliang/spark-11367.
Diffstat (limited to 'python')
-rw-r--r-- python/pyspark/ml/param/_shared_params_code_gen.py 4
-rw-r--r-- python/pyspark/ml/param/shared.py 28
-rw-r--r-- python/pyspark/ml/regression.py 27
3 files changed, 37 insertions, 22 deletions
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index 7143d56330..070c5db01a 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -135,7 +135,9 @@ if __name__ == "__main__":
"values >= 0. The class with largest value p/t is predicted, where p is the original " +
"probability of that class and t is the class' threshold.", None),
("weightCol", "weight column name. If this is not set or empty, we treat " +
- "all instance weights as 1.0.", None)]
+ "all instance weights as 1.0.", None),
+ ("solver", "the solver algorithm for optimization. If this is not set or empty, " +
+ "default value is 'auto'.", "'auto'")]
code = []
for name, doc, defaultValueStr in shared:
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index 3a58ac87d6..4bdf2a8cc5 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -597,6 +597,34 @@ class HasWeightCol(Params):
return self.getOrDefault(self.weightCol)
+class HasSolver(Params):
+ """
+ Mixin for param solver: the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ solver = Param(Params._dummy(), "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.")
+
+ def __init__(self):
+ super(HasSolver, self).__init__()
+ #: param for the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.
+ self.solver = Param(self, "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.")
+ self._setDefault(solver='auto')
+
+ def setSolver(self, value):
+ """
+ Sets the value of :py:attr:`solver`.
+ """
+ self._paramMap[self.solver] = value
+ return self
+
+ def getSolver(self):
+ """
+ Gets the value of solver or its default value.
+ """
+ return self.getOrDefault(self.solver)
+
+
class DecisionTreeParams(Params):
"""
Mixin for Decision Tree parameters.
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index eeb18b3e9d..dc68815556 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -33,7 +33,7 @@ __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
@inherit_doc
class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept,
- HasStandardization):
+ HasStandardization, HasSolver):
"""
Linear regression.
@@ -50,7 +50,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
>>> df = sqlContext.createDataFrame([
... (1.0, Vectors.dense(1.0)),
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
- >>> lr = LinearRegression(maxIter=5, regParam=0.0)
+ >>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal")
>>> model = lr.fit(df)
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
@@ -73,11 +73,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
- standardization=True):
+ standardization=True, solver="auto"):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
- standardization=True)
+ standardization=True, solver="auto")
"""
super(LinearRegression, self).__init__()
self._java_obj = self._new_java_obj(
@@ -90,11 +90,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
@since("1.4.0")
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
- standardization=True):
+ standardization=True, solver="auto"):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
- standardization=True)
+ standardization=True, solver="auto")
Sets params for linear regression.
"""
kwargs = self.setParams._input_kwargs
@@ -103,21 +103,6 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
def _create_model(self, java_model):
return LinearRegressionModel(java_model)
- @since("1.4.0")
- def setElasticNetParam(self, value):
- """
- Sets the value of :py:attr:`elasticNetParam`.
- """
- self._paramMap[self.elasticNetParam] = value
- return self
-
- @since("1.4.0")
- def getElasticNetParam(self):
- """
- Gets the value of elasticNetParam or its default value.
- """
- return self.getOrDefault(self.elasticNetParam)
-
class LinearRegressionModel(JavaModel):
"""