aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala15
1 files changed, 11 insertions, 4 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 2b07524815..fe839e15e9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -27,8 +27,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{DataFrame, Row}
class DecisionTreeClassifierSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -176,7 +175,7 @@ class DecisionTreeClassifierSuite
}
test("Multiclass classification tree with 10-ary (ordered) categorical features," +
- " with just enough bins") {
+ " with just enough bins") {
val rdd = categoricalDataPointsForMulticlassForOrderedFeaturesRDD
val dt = new DecisionTreeClassifier()
.setImpurity("Gini")
@@ -273,7 +272,7 @@ class DecisionTreeClassifierSuite
))
val df = TreeTests.setMetadata(data, Map(0 -> 1), 2)
val dt = new DecisionTreeClassifier().setMaxDepth(3)
- val model = dt.fit(df)
+ dt.fit(df)
}
test("Use soft prediction for binary classification with ordered categorical features") {
@@ -335,6 +334,14 @@ class DecisionTreeClassifierSuite
assert(importances.toArray.forall(_ >= 0.0))
}
+ test("should support all NumericType labels and not support other types") {
+ val dt = new DecisionTreeClassifier().setMaxDepth(1)
+ MLTestingUtils.checkNumericTypes[DecisionTreeClassificationModel, DecisionTreeClassifier](
+ dt, isClassification = true, sqlContext) { (expected, actual) =>
+ TreeTests.checkEqual(expected, actual)
+ }
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////