about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- python/pyspark/ml/classification.py 14
-rw-r--r-- python/pyspark/ml/param/_shared_params_code_gen.py 4
-rw-r--r-- python/pyspark/ml/param/shared.py 24
-rw-r--r-- python/pyspark/ml/regression.py 11
4 files changed, 42 insertions, 11 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 33ada27454..d1522d78fa 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -64,7 +64,7 @@ class JavaClassificationModel(JavaPredictionModel):
class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds,
- HasWeightCol, JavaMLWritable, JavaMLReadable):
+ HasWeightCol, HasAggregationDepth, JavaMLWritable, JavaMLReadable):
"""
Logistic regression.
Currently, this class only supports binary classification.
@@ -121,12 +121,14 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
threshold=0.5, thresholds=None, probabilityCol="probability",
- rawPredictionCol="rawPrediction", standardization=True, weightCol=None):
+ rawPredictionCol="rawPrediction", standardization=True, weightCol=None,
+ aggregationDepth=2):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
threshold=0.5, thresholds=None, probabilityCol="probability", \
- rawPredictionCol="rawPrediction", standardization=True, weightCol=None)
+ rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \
+ aggregationDepth=2)
If the threshold and thresholds Params are both set, they must be equivalent.
"""
super(LogisticRegression, self).__init__()
@@ -142,12 +144,14 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
threshold=0.5, thresholds=None, probabilityCol="probability",
- rawPredictionCol="rawPrediction", standardization=True, weightCol=None):
+ rawPredictionCol="rawPrediction", standardization=True, weightCol=None,
+ aggregationDepth=2):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
threshold=0.5, thresholds=None, probabilityCol="probability", \
- rawPredictionCol="rawPrediction", standardization=True, weightCol=None)
+ rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \
+ aggregationDepth=2)
Sets params for logistic regression.
If the threshold and thresholds Params are both set, they must be equivalent.
"""
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index c32dcc467d..4f4328bcad 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -147,7 +147,9 @@ if __name__ == "__main__":
("solver", "the solver algorithm for optimization. If this is not set or empty, " +
"default value is 'auto'.", "'auto'", "TypeConverters.toString"),
("varianceCol", "column name for the biased sample variance of prediction.",
- None, "TypeConverters.toString")]
+ None, "TypeConverters.toString"),
+ ("aggregationDepth", "suggested depth for treeAggregate (>= 2).", "2",
+ "TypeConverters.toInt")]
code = []
for name, doc, defaultValueStr, typeConverter in shared:
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index c5ccf81540..24af07afc7 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -560,6 +560,30 @@ class HasVarianceCol(Params):
return self.getOrDefault(self.varianceCol)
+class HasAggregationDepth(Params):
+ """
+ Mixin for param aggregationDepth: suggested depth for treeAggregate (>= 2).
+ """
+
+ aggregationDepth = Param(Params._dummy(), "aggregationDepth", "suggested depth for treeAggregate (>= 2).", typeConverter=TypeConverters.toInt)
+
+ def __init__(self):
+ super(HasAggregationDepth, self).__init__()
+ self._setDefault(aggregationDepth=2)
+
+ def setAggregationDepth(self, value):
+ """
+ Sets the value of :py:attr:`aggregationDepth`.
+ """
+ return self._set(aggregationDepth=value)
+
+ def getAggregationDepth(self):
+ """
+ Gets the value of aggregationDepth or its default value.
+ """
+ return self.getOrDefault(self.aggregationDepth)
+
+
class DecisionTreeParams(Params):
"""
Mixin for Decision Tree parameters.
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 56312f672f..19afc723bb 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -39,7 +39,8 @@ __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
@inherit_doc
class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept,
- HasStandardization, HasSolver, HasWeightCol, JavaMLWritable, JavaMLReadable):
+ HasStandardization, HasSolver, HasWeightCol, HasAggregationDepth,
+ JavaMLWritable, JavaMLReadable):
"""
Linear regression.
@@ -97,11 +98,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
- standardization=True, solver="auto", weightCol=None):
+ standardization=True, solver="auto", weightCol=None, aggregationDepth=2):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
- standardization=True, solver="auto", weightCol=None)
+ standardization=True, solver="auto", weightCol=None, aggregationDepth=2)
"""
super(LinearRegression, self).__init__()
self._java_obj = self._new_java_obj(
@@ -114,11 +115,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
@since("1.4.0")
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
- standardization=True, solver="auto", weightCol=None):
+ standardization=True, solver="auto", weightCol=None, aggregationDepth=2):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
- standardization=True, solver="auto", weightCol=None)
+ standardization=True, solver="auto", weightCol=None, aggregationDepth=2)
Sets params for linear regression.
"""
kwargs = self.setParams._input_kwargs