diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2015-08-03 10:46:34 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-08-03 10:46:34 -0700 |
commit | 69f5a7c934ac553ed52c00679b800bcffe83c1d6 (patch) | |
tree | 09e2a238d5ceab0231d2a5a1cfad662e24fcb8f7 /mllib/src/test | |
parent | 8be198c86935001907727fd16577231ff776125b (diff) | |
download | spark-69f5a7c934ac553ed52c00679b800bcffe83c1d6.tar.gz spark-69f5a7c934ac553ed52c00679b800bcffe83c1d6.tar.bz2 spark-69f5a7c934ac553ed52c00679b800bcffe83c1d6.zip |
[SPARK-9528] [ML] Changed RandomForestClassifier to extend ProbabilisticClassifier
RandomForestClassifier now outputs rawPrediction based on tree probabilities, plus probability column computed from normalized rawPrediction.
CC: holdenk
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #7859 from jkbradley/rf-prob and squashes the following commits:
6c28f51 [Joseph K. Bradley] Changed RandomForestClassifier to extend ProbabilisticClassifier
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala | 36 |
1 files changed, 28 insertions, 8 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index dbb2577c62..edf848b21a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} @@ -121,6 +122,33 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte compareAPIs(rdd, rf2, categoricalFeatures, numClasses) } + test("predictRaw and predictProbability") { + val rdd = orderedLabeledPoints5_20 + val rf = new RandomForestClassifier() + .setImpurity("Gini") + .setMaxDepth(3) + .setNumTrees(3) + .setSeed(123) + val categoricalFeatures = Map.empty[Int, Int] + val numClasses = 2 + + val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) + val model = rf.fit(df) + + val predictions = model.transform(df) + .select(rf.getPredictionCol, rf.getRawPredictionCol, rf.getProbabilityCol) + .collect() + + predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) => + assert(pred === rawPred.argmax, + s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.") + val sum = rawPred.toArray.sum + assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred, + "probability prediction mismatch") + assert(probPred.toArray.sum ~== 1.0 relTol 1E-5) + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// @@ -173,13 +201,5 @@ private object RandomForestClassifierSuite { assert(newModel.hasParent) assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent) assert(newModel.numClasses == numClasses) - val results = newModel.transform(newData) - results.select("rawPrediction", "prediction").collect().foreach { - case Row(raw: Vector, prediction: Double) => { - assert(raw.size == numClasses) - val predFromRaw = raw.toArray.zipWithIndex.maxBy(_._1)._2 - assert(predFromRaw == prediction) - } - } } } |