author    sethah <seth.hendrickson16@gmail.com>    2015-09-23 15:00:52 -0700
committer Joseph K. Bradley <joseph@databricks.com>    2015-09-23 15:00:52 -0700
commit    098be27ad53c485ee2fc7f5871c47f899020e87b (patch)
tree      1e6fe63cc0bb8bd6088b4117bc1951fdd6c42507 /mllib/src/main/scala/org
parent    a18208047f06a4244703c17023bb20cbe1f59d73 (diff)
[SPARK-9715] [ML] Store numFeatures in all ML PredictionModel types
All prediction models should store `numFeatures`, indicating the number of features the model was trained on. A default value of -1 is added for backwards compatibility.

Author: sethah <seth.hendrickson16@gmail.com>

Closes #8675 from sethah/SPARK-9715.
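To see what the change buys downstream, here is a minimal spark-shell-style sketch against the Spark 1.6-era API (the app name and the use of the bundled sample LibSVM file are our illustration, not part of the patch):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.ml.classification.LogisticRegression

// Hypothetical driver setup; in spark-shell, sc and sqlContext already exist.
val conf = new SparkConf().setAppName("NumFeaturesDemo").setMaster("local[2]")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)

// data/mllib/sample_libsvm_data.txt ships with the Spark source tree (692 features).
val training = sqlContext.read.format("libsvm")
  .load("data/mllib/sample_libsvm_data.txt")

val model = new LogisticRegression().setMaxIter(10).fit(training)
println(s"trained on ${model.numFeatures} features") // 692; -1 means unknown
sc.stop()
```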
Diffstat (limited to 'mllib/src/main/scala/org')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Predictor.scala                                     |  6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala         | 13
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala                  | 26
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala             |  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala |  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala                     |  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala         |  8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala              | 13
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala                       | 24
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala                   |  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala              |  7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala                        | 14
12 files changed, 84 insertions(+), 35 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 19fe039b8f..e0dcd427fa 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
@@ -145,6 +145,10 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
/** @group setParam */
def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M]
+ /** Returns the number of features the model was trained on. If unknown, returns -1 */
+ @Since("1.6.0")
+ def numFeatures: Int = -1
+
/**
* Returns the SQL DataType corresponding to the FeaturesType type parameter.
*
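Because the base class ships a concrete default instead of an abstract member, existing PredictionModel subclasses keep compiling; models that know their dimensionality simply override it. A dependency-free sketch of the pattern (all names here are hypothetical):

```scala
// The base class ships a safe default; concrete models override it when they know better.
abstract class Model {
  /** Number of features the model was trained on; -1 if unknown. */
  def numFeatures: Int = -1
}

// A model whose weight vector fixes the dimensionality overrides the def with a val.
class LinearModel(val weights: Array[Double]) extends Model {
  override val numFeatures: Int = weights.length
}

object OverrideDemo extends App {
  println(new LinearModel(Array(0.1, 0.2, 0.3)).numFeatures) // 3
  println(new Model {}.numFeatures)                          // -1 (unknown)
}
```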
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index b8eb49f9bd..a6f6d463bf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -107,6 +107,7 @@ object DecisionTreeClassifier {
final class DecisionTreeClassificationModel private[ml] (
override val uid: String,
override val rootNode: Node,
+ override val numFeatures: Int,
override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
with DecisionTreeModel with Serializable {
@@ -118,8 +119,8 @@ final class DecisionTreeClassificationModel private[ml] (
* Construct a decision tree classification model.
* @param rootNode Root node of tree, with other nodes attached.
*/
- private[ml] def this(rootNode: Node, numClasses: Int) =
- this(Identifiable.randomUID("dtc"), rootNode, numClasses)
+ private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
+ this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses)
override protected def predict(features: Vector): Double = {
rootNode.predictImpl(features).prediction
@@ -141,7 +142,7 @@ final class DecisionTreeClassificationModel private[ml] (
}
override def copy(extra: ParamMap): DecisionTreeClassificationModel = {
- copyValues(new DecisionTreeClassificationModel(uid, rootNode, numClasses), extra)
+ copyValues(new DecisionTreeClassificationModel(uid, rootNode, numFeatures, numClasses), extra)
.setParent(parent)
}
@@ -161,12 +162,14 @@ private[ml] object DecisionTreeClassificationModel {
def fromOld(
oldModel: OldDecisionTreeModel,
parent: DecisionTreeClassifier,
- categoricalFeatures: Map[Int, Int]): DecisionTreeClassificationModel = {
+ categoricalFeatures: Map[Int, Int],
+ numFeatures: Int = -1): DecisionTreeClassificationModel = {
require(oldModel.algo == OldAlgo.Classification,
s"Cannot convert non-classification DecisionTreeModel (old API) to" +
s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}")
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc")
- new DecisionTreeClassificationModel(uid, rootNode, -1)
+ // Can't infer number of features from old model, so default to -1
+ new DecisionTreeClassificationModel(uid, rootNode, numFeatures, -1)
}
}
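The converter gains its new argument with a default of -1, so call sites written before the patch keep compiling unchanged. A stand-alone sketch of that source-compatibility trick (names hypothetical):

```scala
object FromOldDemo extends App {
  case class NewModel(numFeatures: Int, numClasses: Int)

  // Appending a defaulted parameter preserves every pre-existing call site.
  def fromOld(numClasses: Int, numFeatures: Int = -1): NewModel =
    NewModel(numFeatures, numClasses)

  println(fromOld(2))                    // NewModel(-1,2): pre-patch call shape
  println(fromOld(2, numFeatures = 692)) // NewModel(692,2): new call shape
}
```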
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index ad8683648b..74aef94bf7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -33,7 +33,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{Row, DataFrame}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType
@@ -138,10 +138,11 @@ final class GBTClassifier(override val uid: String)
require(numClasses == 2,
s"GBTClassifier only supports binary classification but was given numClasses = $numClasses")
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+ val numFeatures = oldDataset.first().features.size
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
val oldGBT = new OldGBT(boostingStrategy)
val oldModel = oldGBT.run(oldDataset)
- GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures)
+ GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures, numFeatures)
}
override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra)
@@ -164,10 +165,11 @@ object GBTClassifier {
* @param _treeWeights Weights for the decision trees in the ensemble.
*/
@Experimental
-final class GBTClassificationModel(
+final class GBTClassificationModel private[ml](
override val uid: String,
private val _trees: Array[DecisionTreeRegressionModel],
- private val _treeWeights: Array[Double])
+ private val _treeWeights: Array[Double],
+ override val numFeatures: Int)
extends PredictionModel[Vector, GBTClassificationModel]
with TreeEnsembleModel with Serializable {
@@ -175,6 +177,14 @@ final class GBTClassificationModel(
require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
+ /**
+ * Construct a GBTClassificationModel
+ * @param _trees Decision trees in the ensemble.
+ * @param _treeWeights Weights for the decision trees in the ensemble.
+ */
+ def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) =
+ this(uid, _trees, _treeWeights, -1)
+
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
override def treeWeights: Array[Double] = _treeWeights
@@ -196,7 +206,8 @@ final class GBTClassificationModel(
}
override def copy(extra: ParamMap): GBTClassificationModel = {
- copyValues(new GBTClassificationModel(uid, _trees, _treeWeights), extra).setParent(parent)
+ copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures),
+ extra).setParent(parent)
}
override def toString: String = {
@@ -215,7 +226,8 @@ private[ml] object GBTClassificationModel {
def fromOld(
oldModel: OldGBTModel,
parent: GBTClassifier,
- categoricalFeatures: Map[Int, Int]): GBTClassificationModel = {
+ categoricalFeatures: Map[Int, Int],
+ numFeatures: Int = -1): GBTClassificationModel = {
require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" +
s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).")
val newTrees = oldModel.trees.map { tree =>
@@ -223,6 +235,6 @@ private[ml] object GBTClassificationModel {
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc")
- new GBTClassificationModel(parent.uid, newTrees, oldModel.treeWeights)
+ new GBTClassificationModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures)
}
}
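train() now reads the dimensionality off the first record of the training RDD before delegating to the old GBT implementation; GBTRegressor below does the same. Sketched in isolation (the helper name is ours), the idea is:

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

// first() triggers a Spark job over at most one partition; the RDD must be
// non-empty, and every row is assumed to share the first row's vector length.
def inferNumFeatures(data: RDD[LabeledPoint]): Int =
  data.first().features.size
```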
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index bd96e8d000..c17a7b0c36 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -426,6 +426,8 @@ class LogisticRegressionModel private[ml] (
1.0 / (1.0 + math.exp(-m))
}
+ override val numFeatures: Int = weights.size
+
override val numClasses: Int = 2
private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None
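For a linear model, the dimensionality is simply the length of the weight vector, which is exactly what the override above returns. A quick illustration with the mllib vectors used by the 1.x ml package:

```scala
import org.apache.spark.mllib.linalg.Vectors

val weights = Vectors.dense(0.5, -1.2, 3.3)
println(weights.size) // 3: a model holding these weights reports numFeatures == 3
```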
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index 5f60dea91f..cd7462596d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -181,6 +181,8 @@ class MultilayerPerceptronClassificationModel private[ml] (
extends PredictionModel[Vector, MultilayerPerceptronClassificationModel]
with Serializable {
+ override val numFeatures: Int = layers.head
+
private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights)
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 082ea1ffad..a14dcecbaf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -137,6 +137,8 @@ class NaiveBayesModel private[ml] (
throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
}
+ override val numFeatures: Int = theta.numCols
+
override val numClasses: Int = pi.size
private def multinomialCalculation(features: Vector) = {
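theta holds the per-class log-conditional probabilities as a matrix with one row per class and one column per feature, so its column count is the feature count. A small check with mllib matrices:

```scala
import org.apache.spark.mllib.linalg.Matrices

// 2 classes x 4 features; dense values are column-major (uniform log-probs here).
val theta = Matrices.dense(2, 4, Array.fill(8)(math.log(0.25)))
println(theta.numCols) // 4 == numFeatures
```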
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index a6ebee1bb1..bae329692a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -119,13 +119,12 @@ object RandomForestClassifier {
* features.
* @param _trees Decision trees in the ensemble.
* Warning: These have null parents.
- * @param numFeatures Number of features used by this model
*/
@Experimental
final class RandomForestClassificationModel private[ml] (
override val uid: String,
private val _trees: Array[DecisionTreeClassificationModel],
- val numFeatures: Int,
+ override val numFeatures: Int,
override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
with TreeEnsembleModel with Serializable {
@@ -226,7 +225,8 @@ private[ml] object RandomForestClassificationModel {
oldModel: OldRandomForestModel,
parent: RandomForestClassifier,
categoricalFeatures: Map[Int, Int],
- numClasses: Int): RandomForestClassificationModel = {
+ numClasses: Int,
+ numFeatures: Int = -1): RandomForestClassificationModel = {
require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" +
s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).")
val newTrees = oldModel.trees.map { tree =>
@@ -234,6 +234,6 @@ private[ml] object RandomForestClassificationModel {
DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc")
- new RandomForestClassificationModel(uid, newTrees, -1, numClasses)
+ new RandomForestClassificationModel(uid, newTrees, numFeatures, numClasses)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index d9a244bea2..88b79a4eb8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -96,7 +96,8 @@ object DecisionTreeRegressor {
@Experimental
final class DecisionTreeRegressionModel private[ml] (
override val uid: String,
- override val rootNode: Node)
+ override val rootNode: Node,
+ override val numFeatures: Int)
extends PredictionModel[Vector, DecisionTreeRegressionModel]
with DecisionTreeModel with Serializable {
@@ -107,14 +108,15 @@ final class DecisionTreeRegressionModel private[ml] (
* Construct a decision tree regression model.
* @param rootNode Root node of tree, with other nodes attached.
*/
- private[ml] def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode)
+ private[ml] def this(rootNode: Node, numFeatures: Int) =
+ this(Identifiable.randomUID("dtr"), rootNode, numFeatures)
override protected def predict(features: Vector): Double = {
rootNode.predictImpl(features).prediction
}
override def copy(extra: ParamMap): DecisionTreeRegressionModel = {
- copyValues(new DecisionTreeRegressionModel(uid, rootNode), extra).setParent(parent)
+ copyValues(new DecisionTreeRegressionModel(uid, rootNode, numFeatures), extra).setParent(parent)
}
override def toString: String = {
@@ -133,12 +135,13 @@ private[ml] object DecisionTreeRegressionModel {
def fromOld(
oldModel: OldDecisionTreeModel,
parent: DecisionTreeRegressor,
- categoricalFeatures: Map[Int, Int]): DecisionTreeRegressionModel = {
+ categoricalFeatures: Map[Int, Int],
+ numFeatures: Int = -1): DecisionTreeRegressionModel = {
require(oldModel.algo == OldAlgo.Regression,
s"Cannot convert non-regression DecisionTreeModel (old API) to" +
s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}")
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr")
- new DecisionTreeRegressionModel(uid, rootNode)
+ new DecisionTreeRegressionModel(uid, rootNode, numFeatures)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index d841ecb9e5..65b5b3e072 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -128,10 +128,11 @@ final class GBTRegressor(override val uid: String)
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+ val numFeatures = oldDataset.first().features.size
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
val oldGBT = new OldGBT(boostingStrategy)
val oldModel = oldGBT.run(oldDataset)
- GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures)
+ GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures, numFeatures)
}
override def copy(extra: ParamMap): GBTRegressor = defaultCopy(extra)
@@ -154,10 +155,11 @@ object GBTRegressor {
* @param _treeWeights Weights for the decision trees in the ensemble.
*/
@Experimental
-final class GBTRegressionModel(
+final class GBTRegressionModel private[ml](
override val uid: String,
private val _trees: Array[DecisionTreeRegressionModel],
- private val _treeWeights: Array[Double])
+ private val _treeWeights: Array[Double],
+ override val numFeatures: Int)
extends PredictionModel[Vector, GBTRegressionModel]
with TreeEnsembleModel with Serializable {
@@ -165,6 +167,14 @@ final class GBTRegressionModel(
require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" +
s" non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
+ /**
+ * Construct a GBTRegressionModel
+ * @param _trees Decision trees in the ensemble.
+ * @param _treeWeights Weights for the decision trees in the ensemble.
+ */
+ def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) =
+ this(uid, _trees, _treeWeights, -1)
+
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
override def treeWeights: Array[Double] = _treeWeights
@@ -185,7 +195,8 @@ final class GBTRegressionModel(
}
override def copy(extra: ParamMap): GBTRegressionModel = {
- copyValues(new GBTRegressionModel(uid, _trees, _treeWeights), extra).setParent(parent)
+ copyValues(new GBTRegressionModel(uid, _trees, _treeWeights, numFeatures),
+ extra).setParent(parent)
}
override def toString: String = {
@@ -204,7 +215,8 @@ private[ml] object GBTRegressionModel {
def fromOld(
oldModel: OldGBTModel,
parent: GBTRegressor,
- categoricalFeatures: Map[Int, Int]): GBTRegressionModel = {
+ categoricalFeatures: Map[Int, Int],
+ numFeatures: Int = -1): GBTRegressionModel = {
require(oldModel.algo == OldAlgo.Regression, "Cannot convert GradientBoostedTreesModel" +
s" with algo=${oldModel.algo} (old API) to GBTRegressionModel (new API).")
val newTrees = oldModel.trees.map { tree =>
@@ -212,6 +224,6 @@ private[ml] object GBTRegressionModel {
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr")
- new GBTRegressionModel(parent.uid, newTrees, oldModel.treeWeights)
+ new GBTRegressionModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures)
}
}
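As with GBTClassificationModel, the widened primary constructor becomes private[ml] while a public auxiliary constructor preserves the old three-argument signature, defaulting numFeatures to -1. A hypothetical stand-alone mirror of that move:

```scala
package org.example.ml

// The package-private primary constructor carries the new field; the public
// auxiliary constructor keeps the pre-patch shape for source compatibility.
final class EnsembleModel private[ml] (
    val uid: String,
    val treeWeights: Array[Double],
    val numFeatures: Int) {

  def this(uid: String, treeWeights: Array[Double]) =
    this(uid, treeWeights, -1)
}
```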
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 78a67c5fda..a77e702141 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -293,6 +293,8 @@ class LinearRegressionModel private[ml] (
private var trainingSummary: Option[LinearRegressionTrainingSummary] = None
+ override val numFeatures: Int = weights.size
+
/**
* Gets summary (e.g. residuals, mse, r-squared ) of model on training set. An exception is
* thrown if `trainingSummary == None`.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index ddb7214416..64fc17247c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -115,7 +115,7 @@ object RandomForestRegressor {
final class RandomForestRegressionModel private[ml] (
override val uid: String,
private val _trees: Array[DecisionTreeRegressionModel],
- val numFeatures: Int)
+ override val numFeatures: Int)
extends PredictionModel[Vector, RandomForestRegressionModel]
with TreeEnsembleModel with Serializable {
@@ -187,13 +187,14 @@ private[ml] object RandomForestRegressionModel {
def fromOld(
oldModel: OldRandomForestModel,
parent: RandomForestRegressor,
- categoricalFeatures: Map[Int, Int]): RandomForestRegressionModel = {
+ categoricalFeatures: Map[Int, Int],
+ numFeatures: Int = -1): RandomForestRegressionModel = {
require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" +
s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).")
val newTrees = oldModel.trees.map { tree =>
// parent for each tree is null since there is no good way to set this.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
- new RandomForestRegressionModel(parent.uid, newTrees, -1)
+ new RandomForestRegressionModel(parent.uid, newTrees, numFeatures)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 4ac51a4754..c494556085 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -179,22 +179,28 @@ private[ml] object RandomForest extends Logging {
}
}
+ val numFeatures = metadata.numFeatures
+
parentUID match {
case Some(uid) =>
if (strategy.algo == OldAlgo.Classification) {
topNodes.map { rootNode =>
- new DecisionTreeClassificationModel(uid, rootNode.toNode, strategy.getNumClasses)
+ new DecisionTreeClassificationModel(uid, rootNode.toNode, numFeatures,
+ strategy.getNumClasses)
}
} else {
- topNodes.map(rootNode => new DecisionTreeRegressionModel(uid, rootNode.toNode))
+ topNodes.map { rootNode =>
+ new DecisionTreeRegressionModel(uid, rootNode.toNode, numFeatures)
+ }
}
case None =>
if (strategy.algo == OldAlgo.Classification) {
topNodes.map { rootNode =>
- new DecisionTreeClassificationModel(rootNode.toNode, strategy.getNumClasses)
+ new DecisionTreeClassificationModel(rootNode.toNode, numFeatures,
+ strategy.getNumClasses)
}
} else {
- topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode))
+ topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode, numFeatures))
}
}
}
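With metadata.numFeatures now threaded through both tree model constructors, a fitted tree reports its training dimensionality. A closing sanity check, reusing the `training` DataFrame from the first sketch (the regressor is chosen because, unlike the tree classifiers, it needs no label metadata):

```scala
import org.apache.spark.ml.regression.DecisionTreeRegressor

val dtModel = new DecisionTreeRegressor().setMaxDepth(3).fit(training)
assert(dtModel.numFeatures == 692) // the new member, fed from metadata.numFeatures
```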