aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <viirya@gmail.com>2016-02-09 17:10:55 -0800
committerJoseph K. Bradley <joseph@databricks.com>2016-02-09 17:10:55 -0800
commit9267bc68fab65c6a798e065a1dbe0f5171df3077 (patch)
treeafbf5313cbc324c134b2c9dc20ed56860bf7e427 /mllib/src/test/scala/org/apache
parent0e5ebac3c1f1ff58f938be59c7c9e604977d269c (diff)
downloadspark-9267bc68fab65c6a798e065a1dbe0f5171df3077.tar.gz
spark-9267bc68fab65c6a798e065a1dbe0f5171df3077.tar.bz2
spark-9267bc68fab65c6a798e065a1dbe0f5171df3077.zip
[SPARK-10524][ML] Use the soft prediction to order categories' bins
JIRA: https://issues.apache.org/jira/browse/SPARK-10524 Currently we use the hard prediction (`ImpurityCalculator.predict`) to order categories' bins. But we should use the soft prediction. Author: Liang-Chi Hsieh <viirya@gmail.com> Author: Liang-Chi Hsieh <viirya@appier.com> Author: Joseph K. Bradley <joseph@databricks.com> Closes #8734 from viirya/dt-soft-centroids.
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala36
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala30
2 files changed, 65 insertions, 1 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index fda2711fed..baf6b90839 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.tree.LeafNode
+import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode}
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
@@ -275,6 +275,40 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
val model = dt.fit(df)
}
+ test("Use soft prediction for binary classification with ordered categorical features") {
+ // The following dataset is set up such that the best split is {1} vs. {0, 2}.
+ // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen.
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0)),
+ LabeledPoint(1.0, Vectors.dense(2.0)))
+ val data = sc.parallelize(arr)
+ val df = TreeTests.setMetadata(data, Map(0 -> 3), 2)
+
+ // Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("gini")
+ .setMaxDepth(1)
+ .setMaxBins(3)
+ val model = dt.fit(df)
+ model.rootNode match {
+ case n: InternalNode =>
+ n.split match {
+ case s: CategoricalSplit =>
+ assert(s.leftCategories === Array(1.0))
+ }
+ }
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index a9c935bd42..dca8ea815a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, Tree
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model._
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
@@ -337,6 +338,35 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(topNode.rightNode.get.impurity === 0.0)
}
+ test("Use soft prediction for binary classification with ordered categorical features") {
+ // The following dataset is set up such that the best split is {1} vs. {0, 2}.
+ // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen.
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0)),
+ LabeledPoint(1.0, Vectors.dense(2.0)))
+ val input = sc.parallelize(arr)
+
+ // Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
+ numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
+
+ val model = new DecisionTree(strategy).run(input)
+ model.topNode.split.get match {
+ case Split(_, _, _, categories: List[Double]) =>
+ assert(categories === List(1.0))
+ }
+ }
+
test("Second level node building with vs. without groups") {
val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
assert(arr.length === 1000)