diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala | 15 |
1 files changed, 8 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 50ac96eb5e..c04c416aaf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -33,7 +33,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ @@ -83,7 +83,7 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val /** @group setParam */ def setVarianceCol(value: String): this.type = set(varianceCol, value) - override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = { + override protected def train(dataset: Dataset[_]): DecisionTreeRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) @@ -158,15 +158,16 @@ final class DecisionTreeRegressionModel private[ml] ( rootNode.predictImpl(features).impurityStats.calculate() } - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) transformImpl(dataset) } - override protected def transformImpl(dataset: DataFrame): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val predictUDF = udf { (features: Vector) => predict(features) } val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) } - var output = dataset + var output = dataset.toDF if ($(predictionCol).nonEmpty) { output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } @@ -203,9 +204,9 @@ final class DecisionTreeRegressionModel private[ml] ( * to determine feature importance instead. */ @Since("2.0.0") - lazy val featureImportances: Vector = RandomForest.featureImportances(this, numFeatures) + lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures) - /** Convert to spark.mllib DecisionTreeModel (losing some infomation) */ + /** Convert to spark.mllib DecisionTreeModel (losing some information) */ override private[spark] def toOld: OldDecisionTreeModel = { new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression) } |