diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2015-10-28 08:54:20 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-10-28 08:54:20 -0700 |
commit | f92b7b98e9998a6069996cc66ca26cbfa695fce5 (patch) | |
tree | dacb106997450f0287aae73bda12d33096b693b3 /python/pyspark/ml/param | |
parent | fba9e95452ca0a9b589bc14b27c750c69f482b8d (diff) | |
download | spark-f92b7b98e9998a6069996cc66ca26cbfa695fce5.tar.gz spark-f92b7b98e9998a6069996cc66ca26cbfa695fce5.tar.bz2 spark-f92b7b98e9998a6069996cc66ca26cbfa695fce5.zip |
[SPARK-11367][ML][PYSPARK] Python LinearRegression should support setting solver
[SPARK-10668](https://issues.apache.org/jira/browse/SPARK-10668) provided the `WeightedLeastSquares` solver ("normal") for `LinearRegression` with L2 regularization in Scala and R. Python ML `LinearRegression` should also support setting the solver ("auto", "normal", "l-bfgs").
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #9328 from yanboliang/spark-11367.
Diffstat (limited to 'python/pyspark/ml/param')
-rw-r--r-- | python/pyspark/ml/param/_shared_params_code_gen.py | 4 | ||||
-rw-r--r-- | python/pyspark/ml/param/shared.py | 28 |
2 files changed, 31 insertions, 1 deletions
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 7143d56330..070c5db01a 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -135,7 +135,9 @@ if __name__ == "__main__": "values >= 0. The class with largest value p/t is predicted, where p is the original " + "probability of that class and t is the class' threshold.", None), ("weightCol", "weight column name. If this is not set or empty, we treat " + - "all instance weights as 1.0.", None)] + "all instance weights as 1.0.", None), + ("solver", "the solver algorithm for optimization. If this is not set or empty, " + + "default value is 'auto'.", "'auto'")] code = [] for name, doc, defaultValueStr in shared: diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 3a58ac87d6..4bdf2a8cc5 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -597,6 +597,34 @@ class HasWeightCol(Params): return self.getOrDefault(self.weightCol) +class HasSolver(Params): + """ + Mixin for param solver: the solver algorithm for optimization. If this is not set or empty, default value is 'auto'. + """ + + # a placeholder to make it appear in the generated doc + solver = Param(Params._dummy(), "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.") + + def __init__(self): + super(HasSolver, self).__init__() + #: param for the solver algorithm for optimization. If this is not set or empty, default value is 'auto'. + self.solver = Param(self, "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.") + self._setDefault(solver='auto') + + def setSolver(self, value): + """ + Sets the value of :py:attr:`solver`. 
+ """ + self._paramMap[self.solver] = value + return self + + def getSolver(self): + """ + Gets the value of solver or its default value. + """ + return self.getOrDefault(self.solver) + + class DecisionTreeParams(Params): """ Mixin for Decision Tree parameters. |