aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-07-31 11:56:52 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-07-31 11:56:52 -0700
commite8bdcdeabb2df139a656f86686cdb53c891b1f4b (patch)
treed6ccab74f50d58e7b18a786ce66dcd8f5fe30f60 /mllib/src/test
parent4011a947154d97a9ffb5a71f077481a12534d36b (diff)
downloadspark-e8bdcdeabb2df139a656f86686cdb53c891b1f4b.tar.gz
spark-e8bdcdeabb2df139a656f86686cdb53c891b1f4b.tar.bz2
spark-e8bdcdeabb2df139a656f86686cdb53c891b1f4b.zip
[SPARK-6885] [ML] decision tree support predict class probabilities
Decision tree support predict class probabilities. Implement the prediction probabilities function referred the old DecisionTree API and the [sklean API](https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/tree.py#L593). I make the DecisionTreeClassificationModel inherit from ProbabilisticClassificationModel, make the predictRaw to return the raw counts vector and make raw2probabilityInPlace/predictProbability return the probabilities for each prediction. Author: Yanbo Liang <ybliang8@gmail.com> Closes #7694 from yanboliang/spark-6885 and squashes the following commits: 08d5b7f [Yanbo Liang] fix ImpurityStats null parameters and raw2probabilityInPlace sum = 0 issue 2174278 [Yanbo Liang] solve merge conflicts 7e90ba8 [Yanbo Liang] fix typos 33ae183 [Yanbo Liang] fix annotation ff043d3 [Yanbo Liang] raw2probabilityInPlace should operate in-place c32d6ce [Yanbo Liang] optimize calculateImpurityStats function again 6167fb0 [Yanbo Liang] optimize calculateImpurityStats function fbbe2ec [Yanbo Liang] eliminate duplicated struct and code beb1634 [Yanbo Liang] try to eliminate impurityStats for each LearningNode 99e8943 [Yanbo Liang] code optimization 5ec3323 [Yanbo Liang] implement InformationGainAndImpurityStats 227c91b [Yanbo Liang] refactor LearningNode to store ImpurityCalculator d746ffc [Yanbo Liang] decision tree support predict class probabilities
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala30
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala2
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala2
3 files changed, 30 insertions, 4 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 73b4805c4c..c7bbf1ce07 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -21,12 +21,13 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.tree.LeafNode
-import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.Row
class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -57,7 +58,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
test("params") {
ParamsSuite.checkParams(new DecisionTreeClassifier)
- val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))
+ val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)
ParamsSuite.checkParams(model)
}
@@ -231,6 +232,31 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
}
+ test("predictRaw and predictProbability") {
+ val rdd = continuousDataPointsForMulticlassRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ .setMaxBins(100)
+ val categoricalFeatures = Map(0 -> 3)
+ val numClasses = 3
+
+ val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
+ val newTree = dt.fit(newData)
+
+ val predictions = newTree.transform(newData)
+ .select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol)
+ .collect()
+
+ predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
+ assert(pred === rawPred.argmax,
+ s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.")
+ val sum = rawPred.toArray.sum
+ assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred,
+ "probability prediction mismatch")
+ }
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index a7bc77965f..d4b5896c12 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -58,7 +58,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
test("params") {
ParamsSuite.checkParams(new GBTClassifier)
val model = new GBTClassificationModel("gbtc",
- Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0))),
+ Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null))),
Array(1.0))
ParamsSuite.checkParams(model)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index ab711c8e4b..dbb2577c62 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -66,7 +66,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
test("params") {
ParamsSuite.checkParams(new RandomForestClassifier)
val model = new RandomForestClassificationModel("rfc",
- Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))), 2)
+ Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2)
ParamsSuite.checkParams(model)
}