diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-01-04 13:32:14 -0800 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-01-04 13:32:14 -0800 |
commit | 93ef9b6a2aa1830170cb101f191022f2dda62c41 (patch) | |
tree | fae4dc34f0bc0bc12d3ee374e97c3090e1f7da90 /mllib/src/test | |
parent | ba5f81859d6ba37a228a1c43d26c47e64c0382cd (diff) | |
download | spark-93ef9b6a2aa1830170cb101f191022f2dda62c41.tar.gz spark-93ef9b6a2aa1830170cb101f191022f2dda62c41.tar.bz2 spark-93ef9b6a2aa1830170cb101f191022f2dda62c41.zip |
[SPARK-9622][ML] DecisionTreeRegressor: provide variance of prediction
DecisionTreeRegressor will provide variance of prediction as a Double column.
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #8866 from yanboliang/spark-9622.
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala | 26 |
1 files changed, 25 insertions, 1 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 6999a910c3..0b39af5543 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -20,12 +20,13 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{Row, DataFrame} class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -73,6 +74,29 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex MLTestingUtils.checkCopy(model) } + test("predictVariance") { + val dt = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(100) + .setPredictionCol("") + .setVarianceCol("variance") + val categoricalFeatures = Map(0 -> 2, 1 -> 2) + + val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) + val model = dt.fit(df) + + val predictions = model.transform(df) + .select(model.getFeaturesCol, model.getVarianceCol) + .collect() + + predictions.foreach { case Row(features: Vector, variance: Double) => + val expectedVariance = model.rootNode.predictImpl(features).impurityStats.calculate() + assert(variance === expectedVariance, + s"Expected variance $expectedVariance but got $variance.") + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// |