diff options
author | 颜发才(Yan Facai) <facai.yan@gmail.com> | 2017-03-28 16:14:01 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2017-03-28 16:14:01 -0700 |
commit | 7d432af8f3c47973550ea253dae0c23cd2961bde (patch) | |
tree | 95a8546b688bd5aed191c9d9061ac8c96e41f51e /mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala | |
parent | 92e385e0b55d70a48411e90aa0f2ed141c4d07c8 (diff) | |
download | spark-7d432af8f3c47973550ea253dae0c23cd2961bde.tar.gz spark-7d432af8f3c47973550ea253dae0c23cd2961bde.tar.bz2 spark-7d432af8f3c47973550ea253dae0c23cd2961bde.zip |
[SPARK-20043][ML] DecisionTreeModel: ImpurityCalculator builder fails for uppercase impurity type Gini
Fix bug: DecisionTreeModel can't recongnize Impurity "Gini" when loading
TODO:
+ [x] add unit test
+ [x] fix the bug
Author: 颜发才(Yan Facai) <facai.yan@gmail.com>
Closes #17407 from facaiy/BUG/decision_tree_loader_failer_with_Gini_impurity.
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 10de50306a..964fcfbdd8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -385,6 +385,20 @@ class DecisionTreeClassifierSuite testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings ++ Map("maxDepth" -> 0), allParamSettings ++ Map("maxDepth" -> 0), checkModelData) } + + test("SPARK-20043: " + + "ImpurityCalculator builder fails for uppercase impurity type Gini in model read/write") { + val rdd = TreeTests.getTreeReadWriteData(sc) + val data: DataFrame = + TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) + + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(2) + val model = dt.fit(data) + + testDefaultReadWrite(model) + } } private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { |