aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author: 颜发才(Yan Facai) <facai.yan@gmail.com> 2017-03-28 16:14:01 -0700
committer: Joseph K. Bradley <joseph@databricks.com> 2017-03-28 16:14:01 -0700
commit: 7d432af8f3c47973550ea253dae0c23cd2961bde (patch)
tree: 95a8546b688bd5aed191c9d9061ac8c96e41f51e /mllib
parent: 92e385e0b55d70a48411e90aa0f2ed141c4d07c8 (diff)
downloadspark-7d432af8f3c47973550ea253dae0c23cd2961bde.tar.gz
spark-7d432af8f3c47973550ea253dae0c23cd2961bde.tar.bz2
spark-7d432af8f3c47973550ea253dae0c23cd2961bde.zip
[SPARK-20043][ML] DecisionTreeModel: ImpurityCalculator builder fails for uppercase impurity type Gini
Fix bug: DecisionTreeModel can't recognize impurity "Gini" when loading. TODO: + [x] add unit test + [x] fix the bug. Author: 颜发才(Yan Facai) <facai.yan@gmail.com>. Closes #17407 from facaiy/BUG/decision_tree_loader_failer_with_Gini_impurity.
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala | 2
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala | 14
2 files changed, 15 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index a5bdc2c6d2..98a3021461 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -184,7 +184,7 @@ private[spark] object ImpurityCalculator {
* the given stats.
*/
def getCalculator(impurity: String, stats: Array[Double]): ImpurityCalculator = {
- impurity match {
+ impurity.toLowerCase match {
case "gini" => new GiniCalculator(stats)
case "entropy" => new EntropyCalculator(stats)
case "variance" => new VarianceCalculator(stats)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 10de50306a..964fcfbdd8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -385,6 +385,20 @@ class DecisionTreeClassifierSuite
testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings ++ Map("maxDepth" -> 0),
allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
}
+
+ test("SPARK-20043: " +
+ "ImpurityCalculator builder fails for uppercase impurity type Gini in model read/write") {
+ val rdd = TreeTests.getTreeReadWriteData(sc)
+ val data: DataFrame =
+ TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
+
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(2)
+ val model = dt.fit(data)
+
+ testDefaultReadWrite(model)
+ }
}
private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite {