aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala15
1 files changed, 8 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 50ac96eb5e..c04c416aaf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -33,7 +33,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
@@ -83,7 +83,7 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val
/** @group setParam */
def setVarianceCol(value: String): this.type = set(varianceCol, value)
- override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = {
+ override protected def train(dataset: Dataset[_]): DecisionTreeRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
@@ -158,15 +158,16 @@ final class DecisionTreeRegressionModel private[ml] (
rootNode.predictImpl(features).impurityStats.calculate()
}
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
transformImpl(dataset)
}
- override protected def transformImpl(dataset: DataFrame): DataFrame = {
+ override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val predictUDF = udf { (features: Vector) => predict(features) }
val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) }
- var output = dataset
+ var output = dataset.toDF
if ($(predictionCol).nonEmpty) {
output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
@@ -203,9 +204,9 @@ final class DecisionTreeRegressionModel private[ml] (
* to determine feature importance instead.
*/
@Since("2.0.0")
- lazy val featureImportances: Vector = RandomForest.featureImportances(this, numFeatures)
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures)
- /** Convert to spark.mllib DecisionTreeModel (losing some infomation) */
+ /** Convert to spark.mllib DecisionTreeModel (losing some information) */
override private[spark] def toOld: OldDecisionTreeModel = {
new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression)
}