From 7d432af8f3c47973550ea253dae0c23cd2961bde Mon Sep 17 00:00:00 2001 From: 颜发才(Yan Facai) Date: Tue, 28 Mar 2017 16:14:01 -0700 Subject: [SPARK-20043][ML] DecisionTreeModel: ImpurityCalculator builder fails for uppercase impurity type Gini MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix bug: DecisionTreeModel can't recongnize Impurity "Gini" when loading TODO: + [x] add unit test + [x] fix the bug Author: 颜发才(Yan Facai) Closes #17407 from facaiy/BUG/decision_tree_loader_failer_with_Gini_impurity. --- .../org/apache/spark/mllib/tree/impurity/Impurity.scala | 2 +- .../ml/classification/DecisionTreeClassifierSuite.scala | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index a5bdc2c6d2..98a3021461 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -184,7 +184,7 @@ private[spark] object ImpurityCalculator { * the given stats. */ def getCalculator(impurity: String, stats: Array[Double]): ImpurityCalculator = { - impurity match { + impurity.toLowerCase match { case "gini" => new GiniCalculator(stats) case "entropy" => new EntropyCalculator(stats) case "variance" => new VarianceCalculator(stats) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 10de50306a..964fcfbdd8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -385,6 +385,20 @@ class DecisionTreeClassifierSuite testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings ++ Map("maxDepth" -> 0), allParamSettings ++ Map("maxDepth" -> 0), checkModelData) } + + test("SPARK-20043: " + + "ImpurityCalculator builder fails for uppercase impurity type Gini in model read/write") { + val rdd = TreeTests.getTreeReadWriteData(sc) + val data: DataFrame = + TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) + + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(2) + val model = dt.fit(data) + + testDefaultReadWrite(model) + } } private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { -- cgit v1.2.3