diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 662e3fc679..e9fb2677b2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -117,6 +117,14 @@ class DecisionTreeRegressorSuite assert(importances.toArray.forall(_ >= 0.0)) } + test("should support all NumericType labels and not support other types") { + val dt = new DecisionTreeRegressor().setMaxDepth(1) + MLTestingUtils.checkNumericTypes[DecisionTreeRegressionModel, DecisionTreeRegressor]( + dt, isClassification = false, sqlContext) { (expected, actual) => + TreeTests.checkEqual(expected, actual) + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// |