diff options
Diffstat (limited to 'python/pyspark/ml')
-rw-r--r-- | python/pyspark/ml/param/_shared_params_code_gen.py | 4 | ||||
-rw-r--r-- | python/pyspark/ml/param/shared.py | 28 | ||||
-rw-r--r-- | python/pyspark/ml/regression.py | 27 |
3 files changed, 37 insertions, 22 deletions
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 7143d56330..070c5db01a 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -135,7 +135,9 @@ if __name__ == "__main__": "values >= 0. The class with largest value p/t is predicted, where p is the original " + "probability of that class and t is the class' threshold.", None), ("weightCol", "weight column name. If this is not set or empty, we treat " + - "all instance weights as 1.0.", None)] + "all instance weights as 1.0.", None), + ("solver", "the solver algorithm for optimization. If this is not set or empty, " + + "default value is 'auto'.", "'auto'")] code = [] for name, doc, defaultValueStr in shared: diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 3a58ac87d6..4bdf2a8cc5 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -597,6 +597,34 @@ class HasWeightCol(Params): return self.getOrDefault(self.weightCol) +class HasSolver(Params): + """ + Mixin for param solver: the solver algorithm for optimization. If this is not set or empty, default value is 'auto'. + """ + + # a placeholder to make it appear in the generated doc + solver = Param(Params._dummy(), "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.") + + def __init__(self): + super(HasSolver, self).__init__() + #: param for the solver algorithm for optimization. If this is not set or empty, default value is 'auto'. + self.solver = Param(self, "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.") + self._setDefault(solver='auto') + + def setSolver(self, value): + """ + Sets the value of :py:attr:`solver`. + """ + self._paramMap[self.solver] = value + return self + + def getSolver(self): + """ + Gets the value of solver or its default value. + """ + return self.getOrDefault(self.solver) + + class DecisionTreeParams(Params): """ Mixin for Decision Tree parameters. diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index eeb18b3e9d..dc68815556 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -33,7 +33,7 @@ __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel', @inherit_doc class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept, - HasStandardization): + HasStandardization, HasSolver): """ Linear regression. @@ -50,7 +50,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction >>> df = sqlContext.createDataFrame([ ... (1.0, Vectors.dense(1.0)), ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) - >>> lr = LinearRegression(maxIter=5, regParam=0.0) + >>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal") >>> model = lr.fit(df) >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction @@ -73,11 +73,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - standardization=True): + standardization=True, solver="auto"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - standardization=True) + standardization=True, solver="auto") """ super(LinearRegression, self).__init__() self._java_obj = self._new_java_obj( @@ -90,11 +90,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - standardization=True): + standardization=True, solver="auto"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - standardization=True) + standardization=True, solver="auto") Sets params for linear regression. """ kwargs = self.setParams._input_kwargs @@ -103,21 +103,6 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction def _create_model(self, java_model): return LinearRegressionModel(java_model) - @since("1.4.0") - def setElasticNetParam(self, value): - """ - Sets the value of :py:attr:`elasticNetParam`. - """ - self._paramMap[self.elasticNetParam] = value - return self - - @since("1.4.0") - def getElasticNetParam(self): - """ - Gets the value of elasticNetParam or its default value. - """ - return self.getOrDefault(self.elasticNetParam) - class LinearRegressionModel(JavaModel): """ |