aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2015-08-03 10:46:34 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-08-03 10:46:34 -0700
commit69f5a7c934ac553ed52c00679b800bcffe83c1d6 (patch)
tree09e2a238d5ceab0231d2a5a1cfad662e24fcb8f7 /mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
parent8be198c86935001907727fd16577231ff776125b (diff)
downloadspark-69f5a7c934ac553ed52c00679b800bcffe83c1d6.tar.gz
spark-69f5a7c934ac553ed52c00679b800bcffe83c1d6.tar.bz2
spark-69f5a7c934ac553ed52c00679b800bcffe83c1d6.zip
[SPARK-9528] [ML] Changed RandomForestClassifier to extend ProbabilisticClassifier
RandomForestClassifier now outputs rawPrediction based on tree probabilities, plus probability column computed from normalized rawPrediction. CC: holdenk Author: Joseph K. Bradley <joseph@databricks.com> Closes #7859 from jkbradley/rf-prob and squashes the following commits: 6c28f51 [Joseph K. Bradley] Changed RandomForestClassifier to extend ProbabilisticClassifier
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala37
1 files changed, 26 insertions, 11 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 0c7eb4a662..56e80cc8fe 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -17,22 +17,19 @@
package org.apache.spark.ml.classification
-import scala.collection.mutable
-
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.tree.impl.RandomForest
-import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.DoubleType
+
/**
* :: Experimental ::
@@ -43,7 +40,7 @@ import org.apache.spark.sql.types.DoubleType
*/
@Experimental
final class RandomForestClassifier(override val uid: String)
- extends Classifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
+ extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
with RandomForestParams with TreeClassifierParams {
def this() = this(Identifiable.randomUID("rfc"))
@@ -127,7 +124,7 @@ final class RandomForestClassificationModel private[ml] (
override val uid: String,
private val _trees: Array[DecisionTreeClassificationModel],
override val numClasses: Int)
- extends ClassificationModel[Vector, RandomForestClassificationModel]
+ extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
with TreeEnsembleModel with Serializable {
require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")
@@ -157,15 +154,33 @@ final class RandomForestClassificationModel private[ml] (
override protected def predictRaw(features: Vector): Vector = {
// TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
// Classifies using majority votes.
- // Ignore the weights since all are 1.0 for now.
- val votes = new Array[Double](numClasses)
+ // Ignore the tree weights since all are 1.0 for now.
+ val votes = Array.fill[Double](numClasses)(0.0)
_trees.view.foreach { tree =>
- val prediction = tree.rootNode.predictImpl(features).prediction.toInt
- votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight
+ val classCounts: Array[Double] = tree.rootNode.predictImpl(features).impurityStats.stats
+ val total = classCounts.sum
+ if (total != 0) {
+ var i = 0
+ while (i < numClasses) {
+ votes(i) += classCounts(i) / total
+ i += 1
+ }
+ }
}
Vectors.dense(votes)
}
+ override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+ rawPrediction match {
+ case dv: DenseVector =>
+ ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(dv)
+ dv
+ case sv: SparseVector =>
+ throw new RuntimeException("Unexpected error in RandomForestClassificationModel:" +
+ " raw2probabilityInPlace encountered SparseVector")
+ }
+ }
+
override def copy(extra: ParamMap): RandomForestClassificationModel = {
copyValues(new RandomForestClassificationModel(uid, _trees, numClasses), extra)
}