aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjohnnywalleye <jsondag@gmail.com>2014-07-09 11:06:34 -0700
committerXiangrui Meng <meng@databricks.com>2014-07-09 11:06:34 -0700
commitd35e3db2325931492b64890125a70579bc3b587b (patch)
tree0f2c65c980cf01ae194e6847f1060538cfc65106
parent0eb11527d13083ced215e3fda44ed849198a57cb (diff)
downloadspark-d35e3db2325931492b64890125a70579bc3b587b.tar.gz
spark-d35e3db2325931492b64890125a70579bc3b587b.tar.bz2
spark-d35e3db2325931492b64890125a70579bc3b587b.zip
[SPARK-2417][MLlib] Fix DecisionTree tests
Fixes test failures introduced by https://github.com/apache/spark/pull/1316. For both the regression and classification cases, val stats is the InformationGainStats for the best tree split. stats.predict is the predicted value for the data, before the split is made. Since 600 of the 1,000 values generated by DecisionTreeSuite.generateCategoricalDataPoints() are 1.0 and the rest 0.0, the regression tree and classification tree both correctly predict a value of 0.6 for this data now, and the assertions have been changed to reflect that. Author: johnnywalleye <jsondag@gmail.com> Closes #1343 from johnnywalleye/decision-tree-tests and squashes the following commits: ef80603 [johnnywalleye] [SPARK-2417][MLlib] Fix DecisionTree tests
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala8
1 files changed, 4 insertions, 4 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 35e92d71dc..bcb11876b8 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -253,8 +253,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val stats = bestSplits(0)._2
assert(stats.gain > 0)
- assert(stats.predict > 0.4)
- assert(stats.predict < 0.5)
+ assert(stats.predict > 0.5)
+ assert(stats.predict < 0.7)
assert(stats.impurity > 0.2)
}
@@ -280,8 +280,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val stats = bestSplits(0)._2
assert(stats.gain > 0)
- assert(stats.predict > 0.4)
- assert(stats.predict < 0.5)
+ assert(stats.predict > 0.5)
+ assert(stats.predict < 0.7)
assert(stats.impurity > 0.2)
}