about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala | 30
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala | 18
-rw-r--r-- python/pyspark/ml/classification.py | 4
3 files changed, 32 insertions, 20 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index fc0693f67c..bc19bd6df8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -25,7 +25,7 @@ import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
@@ -43,7 +43,7 @@ import org.apache.spark.sql.types.DoubleType
*/
@Experimental
final class RandomForestClassifier(override val uid: String)
- extends Predictor[Vector, RandomForestClassifier, RandomForestClassificationModel]
+ extends Classifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
with RandomForestParams with TreeClassifierParams {
def this() = this(Identifiable.randomUID("rfc"))
@@ -98,7 +98,7 @@ final class RandomForestClassifier(override val uid: String)
val trees =
RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
.map(_.asInstanceOf[DecisionTreeClassificationModel])
- new RandomForestClassificationModel(trees)
+ new RandomForestClassificationModel(trees, numClasses)
}
override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra)
@@ -125,8 +125,9 @@ object RandomForestClassifier {
@Experimental
final class RandomForestClassificationModel private[ml] (
override val uid: String,
- private val _trees: Array[DecisionTreeClassificationModel])
- extends PredictionModel[Vector, RandomForestClassificationModel]
+ private val _trees: Array[DecisionTreeClassificationModel],
+ override val numClasses: Int)
+ extends ClassificationModel[Vector, RandomForestClassificationModel]
with TreeEnsembleModel with Serializable {
require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")
@@ -135,8 +136,8 @@ final class RandomForestClassificationModel private[ml] (
* Construct a random forest classification model, with all trees weighted equally.
* @param trees Component trees
*/
- def this(trees: Array[DecisionTreeClassificationModel]) =
- this(Identifiable.randomUID("rfc"), trees)
+ def this(trees: Array[DecisionTreeClassificationModel], numClasses: Int) =
+ this(Identifiable.randomUID("rfc"), trees, numClasses)
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
@@ -153,20 +154,20 @@ final class RandomForestClassificationModel private[ml] (
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
- override protected def predict(features: Vector): Double = {
+ override protected def predictRaw(features: Vector): Vector = {
// TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
// Classifies using majority votes.
// Ignore the weights since all are 1.0 for now.
- val votes = mutable.Map.empty[Int, Double]
+ val votes = new Array[Double](numClasses)
_trees.view.foreach { tree =>
val prediction = tree.rootNode.predict(features).toInt
- votes(prediction) = votes.getOrElse(prediction, 0.0) + 1.0 // 1.0 = weight
+ votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight
}
- votes.maxBy(_._2)._1
+ Vectors.dense(votes)
}
override def copy(extra: ParamMap): RandomForestClassificationModel = {
- copyValues(new RandomForestClassificationModel(uid, _trees), extra)
+ copyValues(new RandomForestClassificationModel(uid, _trees, numClasses), extra)
}
override def toString: String = {
@@ -185,7 +186,8 @@ private[ml] object RandomForestClassificationModel {
def fromOld(
oldModel: OldRandomForestModel,
parent: RandomForestClassifier,
- categoricalFeatures: Map[Int, Int]): RandomForestClassificationModel = {
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int): RandomForestClassificationModel = {
require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" +
s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).")
val newTrees = oldModel.trees.map { tree =>
@@ -193,6 +195,6 @@ private[ml] object RandomForestClassificationModel {
DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc")
- new RandomForestClassificationModel(uid, newTrees)
+ new RandomForestClassificationModel(uid, newTrees, numClasses)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 1b6b69c7dc..ab711c8e4b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -21,13 +21,13 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.tree.LeafNode
-import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
/**
* Test suite for [[RandomForestClassifier]].
@@ -66,7 +66,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
test("params") {
ParamsSuite.checkParams(new RandomForestClassifier)
val model = new RandomForestClassificationModel("rfc",
- Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))))
+ Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))), 2)
ParamsSuite.checkParams(model)
}
@@ -167,9 +167,19 @@ private object RandomForestClassifierSuite {
val newModel = rf.fit(newData)
// Use parent from newTree since this is not checked anyways.
val oldModelAsNew = RandomForestClassificationModel.fromOld(
- oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures)
+ oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures,
+ numClasses)
TreeTests.checkEqual(oldModelAsNew, newModel)
assert(newModel.hasParent)
assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent)
+ assert(newModel.numClasses == numClasses)
+ val results = newModel.transform(newData)
+ results.select("rawPrediction", "prediction").collect().foreach {
+ case Row(raw: Vector, prediction: Double) => {
+ assert(raw.size == numClasses)
+ val predFromRaw = raw.toArray.zipWithIndex.maxBy(_._1)._2
+ assert(predFromRaw == prediction)
+ }
+ }
}
}
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 89117e4928..5a82bc286d 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -299,9 +299,9 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
>>> si_model = stringIndexer.fit(df)
>>> td = si_model.transform(df)
- >>> rf = RandomForestClassifier(numTrees=2, maxDepth=2, labelCol="indexed", seed=42)
+ >>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42)
>>> model = rf.fit(td)
- >>> allclose(model.treeWeights, [1.0, 1.0])
+ >>> allclose(model.treeWeights, [1.0, 1.0, 1.0])
True
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction