diff options
Diffstat (limited to 'python/pyspark/ml/regression.py')
-rw-r--r-- | python/pyspark/ml/regression.py | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 00a6a0de90..f6c5d130dd 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -630,7 +630,7 @@ class GBTParams(TreeEnsembleParams): @inherit_doc class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval, - HasSeed, JavaMLWritable, JavaMLReadable): + HasSeed, JavaMLWritable, JavaMLReadable, HasVarianceCol): """ `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree` learning algorithm for regression. @@ -640,7 +640,7 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi >>> df = sqlContext.createDataFrame([ ... (1.0, Vectors.dense(1.0)), ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) - >>> dt = DecisionTreeRegressor(maxDepth=2) + >>> dt = DecisionTreeRegressor(maxDepth=2, varianceCol="variance") >>> model = dt.fit(df) >>> model.depth 1 @@ -666,6 +666,8 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi True >>> model.depth == model2.depth True + >>> model.transform(test1).head().variance + 0.0 .. versionadded:: 1.4.0 """ @@ -674,12 +676,12 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance", - seed=None): + seed=None, varianceCol=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ - impurity="variance", seed=None) + impurity="variance", seed=None, varianceCol=None) """ super(DecisionTreeRegressor, self).__init__() self._java_obj = self._new_java_obj( @@ -695,12 +697,12 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - impurity="variance", seed=None): + impurity="variance", seed=None, varianceCol=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ - impurity="variance", seed=None) + impurity="variance", seed=None, varianceCol=None) Sets params for the DecisionTreeRegressor. """ kwargs = self.setParams._input_kwargs |