aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-08-25 02:26:33 -0700
committerYanbo Liang <ybliang8@gmail.com>2016-08-25 02:26:33 -0700
commit6b8cb1fe52e2c8b4b87b0c7d820f3a1824287328 (patch)
tree1e08638b25aedd2adcd9b84612a87cc095508a13
parente0b20f9f24d5c3304bf517a4dcfb0da93be5bc75 (diff)
downloadspark-6b8cb1fe52e2c8b4b87b0c7d820f3a1824287328.tar.gz
spark-6b8cb1fe52e2c8b4b87b0c7d820f3a1824287328.tar.bz2
spark-6b8cb1fe52e2c8b4b87b0c7d820f3a1824287328.zip
[SPARK-17197][ML][PYSPARK] PySpark LiR/LoR supports tree aggregation level configurable.
## What changes were proposed in this pull request? [SPARK-17090](https://issues.apache.org/jira/browse/SPARK-17090) makes tree aggregation level in LiR/LoR configurable, this PR makes PySpark support this function. ## How was this patch tested? Since ```aggregationDepth``` is an expert param, I'm not prefer to test it in doctest which is also used for example. Here is the offline test result: ![image](https://cloud.githubusercontent.com/assets/1962026/17879457/f83d7760-68a6-11e6-9936-d0a884d5d6ec.png) Author: Yanbo Liang <ybliang8@gmail.com> Closes #14766 from yanboliang/spark-17197.
-rw-r--r--python/pyspark/ml/classification.py14
-rw-r--r--python/pyspark/ml/param/_shared_params_code_gen.py4
-rw-r--r--python/pyspark/ml/param/shared.py24
-rw-r--r--python/pyspark/ml/regression.py11
4 files changed, 42 insertions, 11 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 33ada27454..d1522d78fa 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -64,7 +64,7 @@ class JavaClassificationModel(JavaPredictionModel):
class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds,
- HasWeightCol, JavaMLWritable, JavaMLReadable):
+ HasWeightCol, HasAggregationDepth, JavaMLWritable, JavaMLReadable):
"""
Logistic regression.
Currently, this class only supports binary classification.
@@ -121,12 +121,14 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
threshold=0.5, thresholds=None, probabilityCol="probability",
- rawPredictionCol="rawPrediction", standardization=True, weightCol=None):
+ rawPredictionCol="rawPrediction", standardization=True, weightCol=None,
+ aggregationDepth=2):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
threshold=0.5, thresholds=None, probabilityCol="probability", \
- rawPredictionCol="rawPrediction", standardization=True, weightCol=None)
+ rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \
+ aggregationDepth=2)
If the threshold and thresholds Params are both set, they must be equivalent.
"""
super(LogisticRegression, self).__init__()
@@ -142,12 +144,14 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
threshold=0.5, thresholds=None, probabilityCol="probability",
- rawPredictionCol="rawPrediction", standardization=True, weightCol=None):
+ rawPredictionCol="rawPrediction", standardization=True, weightCol=None,
+ aggregationDepth=2):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
threshold=0.5, thresholds=None, probabilityCol="probability", \
- rawPredictionCol="rawPrediction", standardization=True, weightCol=None)
+ rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \
+ aggregationDepth=2)
Sets params for logistic regression.
If the threshold and thresholds Params are both set, they must be equivalent.
"""
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index c32dcc467d..4f4328bcad 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -147,7 +147,9 @@ if __name__ == "__main__":
("solver", "the solver algorithm for optimization. If this is not set or empty, " +
"default value is 'auto'.", "'auto'", "TypeConverters.toString"),
("varianceCol", "column name for the biased sample variance of prediction.",
- None, "TypeConverters.toString")]
+ None, "TypeConverters.toString"),
+ ("aggregationDepth", "suggested depth for treeAggregate (>= 2).", "2",
+ "TypeConverters.toInt")]
code = []
for name, doc, defaultValueStr, typeConverter in shared:
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index c5ccf81540..24af07afc7 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -560,6 +560,30 @@ class HasVarianceCol(Params):
return self.getOrDefault(self.varianceCol)
+class HasAggregationDepth(Params):
+ """
+ Mixin for param aggregationDepth: suggested depth for treeAggregate (>= 2).
+ """
+
+ aggregationDepth = Param(Params._dummy(), "aggregationDepth", "suggested depth for treeAggregate (>= 2).", typeConverter=TypeConverters.toInt)
+
+ def __init__(self):
+ super(HasAggregationDepth, self).__init__()
+ self._setDefault(aggregationDepth=2)
+
+ def setAggregationDepth(self, value):
+ """
+ Sets the value of :py:attr:`aggregationDepth`.
+ """
+ return self._set(aggregationDepth=value)
+
+ def getAggregationDepth(self):
+ """
+ Gets the value of aggregationDepth or its default value.
+ """
+ return self.getOrDefault(self.aggregationDepth)
+
+
class DecisionTreeParams(Params):
"""
Mixin for Decision Tree parameters.
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 56312f672f..19afc723bb 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -39,7 +39,8 @@ __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
@inherit_doc
class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept,
- HasStandardization, HasSolver, HasWeightCol, JavaMLWritable, JavaMLReadable):
+ HasStandardization, HasSolver, HasWeightCol, HasAggregationDepth,
+ JavaMLWritable, JavaMLReadable):
"""
Linear regression.
@@ -97,11 +98,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
- standardization=True, solver="auto", weightCol=None):
+ standardization=True, solver="auto", weightCol=None, aggregationDepth=2):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
- standardization=True, solver="auto", weightCol=None)
+ standardization=True, solver="auto", weightCol=None, aggregationDepth=2)
"""
super(LinearRegression, self).__init__()
self._java_obj = self._new_java_obj(
@@ -114,11 +115,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
@since("1.4.0")
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
- standardization=True, solver="auto", weightCol=None):
+ standardization=True, solver="auto", weightCol=None, aggregationDepth=2):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
- standardization=True, solver="auto", weightCol=None)
+ standardization=True, solver="auto", weightCol=None, aggregationDepth=2)
Sets params for linear regression.
"""
kwargs = self.setParams._input_kwargs