aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorMechCoder <mks542@nyu.edu>2016-07-06 02:54:44 -0700
committerYanbo Liang <ybliang8@gmail.com>2016-07-06 02:54:44 -0700
commit909c6d812f6ca3a3305e4611a700c8c17905b953 (patch)
treeb6583d444403d8aaf9fbc61eab5177ddf66a9b9c /mllib
parent7e28fabdff2da1cc374efbf43372d92ae0cd07aa (diff)
downloadspark-909c6d812f6ca3a3305e4611a700c8c17905b953.tar.gz
spark-909c6d812f6ca3a3305e4611a700c8c17905b953.tar.bz2
spark-909c6d812f6ca3a3305e4611a700c8c17905b953.zip
[SPARK-16307][ML] Add test to verify the predicted variances of a DT on toy data
## What changes were proposed in this pull request? The current tests assumes that `impurity.calculate()` returns the variance correctly. It should be better to make the tests independent of this assumption. In other words verify that the variance computed equals the variance computed manually on a small tree. ## How was this patch tested? The patch is a test.... Author: MechCoder <mks542@nyu.edu> Closes #13981 from MechCoder/dt_variance.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala20
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala12
2 files changed, 32 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 9afb742406..15fa26e8b5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -22,6 +22,7 @@ import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
DecisionTreeSuite => OldDecisionTreeSuite}
@@ -96,6 +97,25 @@ class DecisionTreeRegressorSuite
assert(variance === expectedVariance,
s"Expected variance $expectedVariance but got $variance.")
}
+
+ val varianceData: RDD[LabeledPoint] = TreeTests.varianceData(sc)
+ val varianceDF = TreeTests.setMetadata(varianceData, Map.empty[Int, Int], 0)
+ dt.setMaxDepth(1)
+ .setMaxBins(6)
+ .setSeed(0)
+ val transformVarDF = dt.fit(varianceDF).transform(varianceDF)
+ val calculatedVariances = transformVarDF.select(dt.getVarianceCol).collect().map {
+ case Row(variance: Double) => variance
+ }
+
+ // Since max depth is set to 1, the best split point is that which splits the data
+ // into (0.0, 1.0, 2.0) and (10.0, 12.0, 14.0). The predicted variance for each
+ // data point in the left node is 0.667 and for each data point in the right node
+ // is 2.667
+ val expectedVariances = Array(0.667, 0.667, 0.667, 2.667, 2.667, 2.667)
+ calculatedVariances.zip(expectedVariances).foreach { case (actual, expected) =>
+ assert(actual ~== expected absTol 1e-3)
+ }
}
test("Feature importance with toy data") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
index d2fa8d0d63..c90cb8ca10 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
@@ -183,6 +183,18 @@ private[ml] object TreeTests extends SparkFunSuite {
))
/**
+ * Create some toy data for testing correctness of variance.
+ */
+ def varianceData(sc: SparkContext): RDD[LabeledPoint] = sc.parallelize(Seq(
+ new LabeledPoint(1.0, Vectors.dense(Array(0.0))),
+ new LabeledPoint(2.0, Vectors.dense(Array(1.0))),
+ new LabeledPoint(3.0, Vectors.dense(Array(2.0))),
+ new LabeledPoint(10.0, Vectors.dense(Array(3.0))),
+ new LabeledPoint(12.0, Vectors.dense(Array(4.0))),
+ new LabeledPoint(14.0, Vectors.dense(Array(5.0)))
+ ))
+
+ /**
* Mapping from all Params to valid settings which differ from the defaults.
* This is useful for tests which need to exercise all Params, such as save/load.
* This excludes input columns to simplify some tests.