aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-01-04 13:32:14 -0800
committerJoseph K. Bradley <joseph@databricks.com>2016-01-04 13:32:14 -0800
commit93ef9b6a2aa1830170cb101f191022f2dda62c41 (patch)
treefae4dc34f0bc0bc12d3ee374e97c3090e1f7da90 /mllib/src/test/scala/org/apache
parentba5f81859d6ba37a228a1c43d26c47e64c0382cd (diff)
downloadspark-93ef9b6a2aa1830170cb101f191022f2dda62c41.tar.gz
spark-93ef9b6a2aa1830170cb101f191022f2dda62c41.tar.bz2
spark-93ef9b6a2aa1830170cb101f191022f2dda62c41.zip
[SPARK-9622][ML] DecisionTreeRegressor: provide variance of prediction
DecisionTreeRegressor will provide variance of prediction as a Double column. Author: Yanbo Liang <ybliang8@gmail.com> Closes #8866 from yanboliang/spark-9622.
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala26
1 files changed, 25 insertions, 1 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 6999a910c3..0b39af5543 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -20,12 +20,13 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
DecisionTreeSuite => OldDecisionTreeSuite}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{Row, DataFrame}
class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -73,6 +74,29 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
MLTestingUtils.checkCopy(model)
}
+ test("predictVariance") {
+ val dt = new DecisionTreeRegressor()
+ .setImpurity("variance")
+ .setMaxDepth(2)
+ .setMaxBins(100)
+ .setPredictionCol("")
+ .setVarianceCol("variance")
+ val categoricalFeatures = Map(0 -> 2, 1 -> 2)
+
+ val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0)
+ val model = dt.fit(df)
+
+ val predictions = model.transform(df)
+ .select(model.getFeaturesCol, model.getVarianceCol)
+ .collect()
+
+ predictions.foreach { case Row(features: Vector, variance: Double) =>
+ val expectedVariance = model.rootNode.predictImpl(features).impurityStats.calculate()
+ assert(variance === expectedVariance,
+ s"Expected variance $expectedVariance but got $variance.")
+ }
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////