aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala14
1 files changed, 14 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 10de50306a..964fcfbdd8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -385,6 +385,20 @@ class DecisionTreeClassifierSuite
testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings ++ Map("maxDepth" -> 0),
allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
}
+
+ test("SPARK-20043: " +
+ "ImpurityCalculator builder fails for uppercase impurity type Gini in model read/write") {
+ val rdd = TreeTests.getTreeReadWriteData(sc)
+ val data: DataFrame =
+ TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
+
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(2)
+ val model = dt.fit(data)
+
+ testDefaultReadWrite(model)
+ }
}
private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite {