aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/regression.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/regression.py')
-rw-r--r--python/pyspark/ml/regression.py14
1 files changed, 8 insertions, 6 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 00a6a0de90..f6c5d130dd 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -630,7 +630,7 @@ class GBTParams(TreeEnsembleParams):
@inherit_doc
class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval,
- HasSeed, JavaMLWritable, JavaMLReadable):
+ HasSeed, JavaMLWritable, JavaMLReadable, HasVarianceCol):
"""
`http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree`
learning algorithm for regression.
@@ -640,7 +640,7 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
>>> df = sqlContext.createDataFrame([
... (1.0, Vectors.dense(1.0)),
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
- >>> dt = DecisionTreeRegressor(maxDepth=2)
+ >>> dt = DecisionTreeRegressor(maxDepth=2, varianceCol="variance")
>>> model = dt.fit(df)
>>> model.depth
1
@@ -666,6 +666,8 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
True
>>> model.depth == model2.depth
True
+ >>> model.transform(test1).head().variance
+ 0.0
.. versionadded:: 1.4.0
"""
@@ -674,12 +676,12 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance",
- seed=None):
+ seed=None, varianceCol=None):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
- impurity="variance", seed=None)
+ impurity="variance", seed=None, varianceCol=None)
"""
super(DecisionTreeRegressor, self).__init__()
self._java_obj = self._new_java_obj(
@@ -695,12 +697,12 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
- impurity="variance", seed=None):
+ impurity="variance", seed=None, varianceCol=None):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
- impurity="variance", seed=None)
+ impurity="variance", seed=None, varianceCol=None)
Sets params for the DecisionTreeRegressor.
"""
kwargs = self.setParams._input_kwargs