diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala | 15 |
1 files changed, 11 insertions, 4 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 2b07524815..fe839e15e9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -27,8 +27,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.Row +import org.apache.spark.sql.{DataFrame, Row} class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -176,7 +175,7 @@ class DecisionTreeClassifierSuite } test("Multiclass classification tree with 10-ary (ordered) categorical features," + - " with just enough bins") { + " with just enough bins") { val rdd = categoricalDataPointsForMulticlassForOrderedFeaturesRDD val dt = new DecisionTreeClassifier() .setImpurity("Gini") @@ -273,7 +272,7 @@ class DecisionTreeClassifierSuite )) val df = TreeTests.setMetadata(data, Map(0 -> 1), 2) val dt = new DecisionTreeClassifier().setMaxDepth(3) - val model = dt.fit(df) + dt.fit(df) } test("Use soft prediction for binary classification with ordered categorical features") { @@ -335,6 +334,14 @@ class DecisionTreeClassifierSuite assert(importances.toArray.forall(_ >= 0.0)) } + test("should support all NumericType labels and not support other types") { + val dt = new DecisionTreeClassifier().setMaxDepth(1) + MLTestingUtils.checkNumericTypes[DecisionTreeClassificationModel, DecisionTreeClassifier]( + dt, isClassification = true, sqlContext) { (expected, actual) => + TreeTests.checkEqual(expected, actual) + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// |