about | summary | refs | log | tree | commit | diff
path: root/mllib/src/test
diff options
context:
space:
mode:
author: Joseph K. Bradley <joseph@databricks.com> 2015-08-03 10:46:34 -0700
committer: Joseph K. Bradley <joseph@databricks.com> 2015-08-03 10:46:34 -0700
commit: 69f5a7c934ac553ed52c00679b800bcffe83c1d6 (patch)
tree: 09e2a238d5ceab0231d2a5a1cfad662e24fcb8f7 /mllib/src/test
parent: 8be198c86935001907727fd16577231ff776125b (diff)
download: spark-69f5a7c934ac553ed52c00679b800bcffe83c1d6.tar.gz
spark-69f5a7c934ac553ed52c00679b800bcffe83c1d6.tar.bz2
spark-69f5a7c934ac553ed52c00679b800bcffe83c1d6.zip
[SPARK-9528] [ML] Changed RandomForestClassifier to extend ProbabilisticClassifier
RandomForestClassifier now outputs rawPrediction based on tree probabilities, plus probability column computed from normalized rawPrediction. CC: holdenk Author: Joseph K. Bradley <joseph@databricks.com> Closes #7859 from jkbradley/rf-prob and squashes the following commits: 6c28f51 [Joseph K. Bradley] Changed RandomForestClassifier to extend ProbabilisticClassifier
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala 36
1 file changed, 28 insertions, 8 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index dbb2577c62..edf848b21a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
@@ -121,6 +122,33 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
compareAPIs(rdd, rf2, categoricalFeatures, numClasses)
}
+ test("predictRaw and predictProbability") {
+ val rdd = orderedLabeledPoints5_20
+ val rf = new RandomForestClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(3)
+ .setNumTrees(3)
+ .setSeed(123)
+ val categoricalFeatures = Map.empty[Int, Int]
+ val numClasses = 2
+
+ val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
+ val model = rf.fit(df)
+
+ val predictions = model.transform(df)
+ .select(rf.getPredictionCol, rf.getRawPredictionCol, rf.getProbabilityCol)
+ .collect()
+
+ predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
+ assert(pred === rawPred.argmax,
+ s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.")
+ val sum = rawPred.toArray.sum
+ assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred,
+ "probability prediction mismatch")
+ assert(probPred.toArray.sum ~== 1.0 relTol 1E-5)
+ }
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
@@ -173,13 +201,5 @@ private object RandomForestClassifierSuite {
assert(newModel.hasParent)
assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent)
assert(newModel.numClasses == numClasses)
- val results = newModel.transform(newData)
- results.select("rawPrediction", "prediction").collect().foreach {
- case Row(raw: Vector, prediction: Double) => {
- assert(raw.size == numClasses)
- val predFromRaw = raw.toArray.zipWithIndex.maxBy(_._1)._2
- assert(predFromRaw == prediction)
- }
- }
}
}