diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2015-08-03 10:46:34 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-08-03 10:46:34 -0700 |
commit | 69f5a7c934ac553ed52c00679b800bcffe83c1d6 (patch) | |
tree | 09e2a238d5ceab0231d2a5a1cfad662e24fcb8f7 /mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala | |
parent | 8be198c86935001907727fd16577231ff776125b (diff) | |
download | spark-69f5a7c934ac553ed52c00679b800bcffe83c1d6.tar.gz spark-69f5a7c934ac553ed52c00679b800bcffe83c1d6.tar.bz2 spark-69f5a7c934ac553ed52c00679b800bcffe83c1d6.zip |
[SPARK-9528] [ML] Changed RandomForestClassifier to extend ProbabilisticClassifier
RandomForestClassifier now outputs rawPrediction based on tree probabilities, plus probability column computed from normalized rawPrediction.
CC: holdenk
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #7859 from jkbradley/rf-prob and squashes the following commits:
6c28f51 [Joseph K. Bradley] Changed RandomForestClassifier to extend ProbabilisticClassifier
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala | 37 |
1 files changed, 26 insertions, 11 deletions
```diff
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 0c7eb4a662..56e80cc8fe 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -17,22 +17,19 @@
 
 package org.apache.spark.ml.classification
 
-import scala.collection.mutable
-
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.tree.impl.RandomForest
-import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
 import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.DoubleType
+
 
 /**
  * :: Experimental ::
@@ -43,7 +40,7 @@ import org.apache.spark.sql.types.DoubleType
  */
 @Experimental
 final class RandomForestClassifier(override val uid: String)
-  extends Classifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
+  extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
   with RandomForestParams with TreeClassifierParams {
 
   def this() = this(Identifiable.randomUID("rfc"))
@@ -127,7 +124,7 @@ final class RandomForestClassificationModel private[ml] (
     override val uid: String,
     private val _trees: Array[DecisionTreeClassificationModel],
     override val numClasses: Int)
-  extends ClassificationModel[Vector, RandomForestClassificationModel]
+  extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
   with TreeEnsembleModel with Serializable {
 
   require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")
@@ -157,15 +154,33 @@ final class RandomForestClassificationModel private[ml] (
   override protected def predictRaw(features: Vector): Vector = {
     // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
     // Classifies using majority votes.
-    // Ignore the weights since all are 1.0 for now.
-    val votes = new Array[Double](numClasses)
+    // Ignore the tree weights since all are 1.0 for now.
+    val votes = Array.fill[Double](numClasses)(0.0)
     _trees.view.foreach { tree =>
-      val prediction = tree.rootNode.predictImpl(features).prediction.toInt
-      votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight
+      val classCounts: Array[Double] = tree.rootNode.predictImpl(features).impurityStats.stats
+      val total = classCounts.sum
+      if (total != 0) {
+        var i = 0
+        while (i < numClasses) {
+          votes(i) += classCounts(i) / total
+          i += 1
+        }
+      }
     }
     Vectors.dense(votes)
   }
 
+  override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+    rawPrediction match {
+      case dv: DenseVector =>
+        ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(dv)
+        dv
+      case sv: SparseVector =>
+        throw new RuntimeException("Unexpected error in RandomForestClassificationModel:" +
+          " raw2probabilityInPlace encountered SparseVector")
+    }
+  }
+
   override def copy(extra: ParamMap): RandomForestClassificationModel = {
     copyValues(new RandomForestClassificationModel(uid, _trees, numClasses), extra)
   }
```