diff options
author | Liang-Chi Hsieh <viirya@gmail.com> | 2016-02-09 17:10:55 -0800 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-02-09 17:10:55 -0800 |
commit | 9267bc68fab65c6a798e065a1dbe0f5171df3077 (patch) | |
tree | afbf5313cbc324c134b2c9dc20ed56860bf7e427 /mllib/src/test | |
parent | 0e5ebac3c1f1ff58f938be59c7c9e604977d269c (diff) | |
download | spark-9267bc68fab65c6a798e065a1dbe0f5171df3077.tar.gz spark-9267bc68fab65c6a798e065a1dbe0f5171df3077.tar.bz2 spark-9267bc68fab65c6a798e065a1dbe0f5171df3077.zip |
[SPARK-10524][ML] Use the soft prediction to order categories' bins
JIRA: https://issues.apache.org/jira/browse/SPARK-10524
Currently we use the hard prediction (`ImpurityCalculator.predict`) to order categories' bins. But we should use the soft prediction.
Author: Liang-Chi Hsieh <viirya@gmail.com>
Author: Liang-Chi Hsieh <viirya@appier.com>
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #8734 from viirya/dt-soft-centroids.
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala | 36 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala | 30 |
2 files changed, 65 insertions, 1 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index fda2711fed..baf6b90839 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode} import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint @@ -275,6 +275,40 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte val model = dt.fit(df) } + test("Use soft prediction for binary classification with ordered categorical features") { + // The following dataset is set up such that the best split is {1} vs. {0, 2}. + // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen. + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(2.0))) + val data = sc.parallelize(arr) + val df = TreeTests.setMetadata(data, Map(0 -> 3), 2) + + // Must set maxBins s.t. the feature will be treated as an ordered categorical feature. + val dt = new DecisionTreeClassifier() + .setImpurity("gini") + .setMaxDepth(1) + .setMaxBins(3) + val model = dt.fit(df) + model.rootNode match { + case n: InternalNode => + n.split match { + case s: CategoricalSplit => + assert(s.leftCategories === Array(1.0)) + } + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index a9c935bd42..dca8ea815a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, Tree import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model._ import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils @@ -337,6 +338,35 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.rightNode.get.impurity === 0.0) } + test("Use soft prediction for binary classification with ordered categorical features") { + // The following dataset is set up such that the best split is {1} vs. {0, 2}. + // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen. + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(2.0))) + val input = sc.parallelize(arr) + + // Must set maxBins s.t. the feature will be treated as an ordered categorical feature. + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1, + numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3) + + val model = new DecisionTree(strategy).run(input) + model.topNode.split.get match { + case Split(_, _, _, categories: List[Double]) => + assert(categories === List(1.0)) + } + } + test("Second level node building with vs. without groups") { val arr = DecisionTreeSuite.generateOrderedLabeledPoints() assert(arr.length === 1000) |