| author | sethah <seth.hendrickson16@gmail.com> | 2015-09-23 15:00:52 -0700 |
|---|---|---|
| committer | Joseph K. Bradley <joseph@databricks.com> | 2015-09-23 15:00:52 -0700 |
| commit | 098be27ad53c485ee2fc7f5871c47f899020e87b (patch) | |
| tree | 1e6fe63cc0bb8bd6088b4117bc1951fdd6c42507 /mllib/src/main/scala/org | |
| parent | a18208047f06a4244703c17023bb20cbe1f59d73 (diff) | |
| download | spark-098be27ad53c485ee2fc7f5871c47f899020e87b.tar.gz, spark-098be27ad53c485ee2fc7f5871c47f899020e87b.tar.bz2, spark-098be27ad53c485ee2fc7f5871c47f899020e87b.zip | |
[SPARK-9715] [ML] Store numFeatures in all ML PredictionModel types
All prediction models should store `numFeatures`, indicating the number of features the model was trained on. A default value of -1 is provided for backwards compatibility.
Author: sethah <seth.hendrickson16@gmail.com>
Closes #8675 from sethah/SPARK-9715.
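With this change, every fitted `PredictionModel` exposes `numFeatures`. A minimal usage sketch (the DataFrame `training` with "label" and "features" columns is a hypothetical stand-in, not part of this patch):

import org.apache.spark.ml.classification.LogisticRegression

val lr = new LogisticRegression()
val model = lr.fit(training)  // `training` is assumed to exist with "label"/"features" columns
// Models that can infer their feature count report it; the base-class default is -1 (unknown).
println(s"numFeatures = ${model.numFeatures}")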
Diffstat (limited to 'mllib/src/main/scala/org')
12 files changed, 84 insertions, 35 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 19fe039b8f..e0dcd427fa 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.ml
 
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
@@ -145,6 +145,10 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
   /** @group setParam */
   def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M]
 
+  /** Returns the number of features the model was trained on. If unknown, returns -1 */
+  @Since("1.6.0")
+  def numFeatures: Int = -1
+
   /**
    * Returns the SQL DataType corresponding to the FeaturesType type parameter.
    *
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index b8eb49f9bd..a6f6d463bf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -107,6 +107,7 @@ object DecisionTreeClassifier {
 final class DecisionTreeClassificationModel private[ml] (
     override val uid: String,
     override val rootNode: Node,
+    override val numFeatures: Int,
     override val numClasses: Int)
   extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
   with DecisionTreeModel with Serializable {
@@ -118,8 +119,8 @@ final class DecisionTreeClassificationModel private[ml] (
    * Construct a decision tree classification model.
    * @param rootNode  Root node of tree, with other nodes attached.
    */
-  private[ml] def this(rootNode: Node, numClasses: Int) =
-    this(Identifiable.randomUID("dtc"), rootNode, numClasses)
+  private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
+    this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses)
 
   override protected def predict(features: Vector): Double = {
     rootNode.predictImpl(features).prediction
@@ -141,7 +142,7 @@ final class DecisionTreeClassificationModel private[ml] (
   }
 
   override def copy(extra: ParamMap): DecisionTreeClassificationModel = {
-    copyValues(new DecisionTreeClassificationModel(uid, rootNode, numClasses), extra)
+    copyValues(new DecisionTreeClassificationModel(uid, rootNode, numFeatures, numClasses), extra)
       .setParent(parent)
   }
@@ -161,12 +162,14 @@ private[ml] object DecisionTreeClassificationModel {
   def fromOld(
       oldModel: OldDecisionTreeModel,
       parent: DecisionTreeClassifier,
-      categoricalFeatures: Map[Int, Int]): DecisionTreeClassificationModel = {
+      categoricalFeatures: Map[Int, Int],
+      numFeatures: Int = -1): DecisionTreeClassificationModel = {
     require(oldModel.algo == OldAlgo.Classification,
       s"Cannot convert non-classification DecisionTreeModel (old API) to" +
         s" DecisionTreeClassificationModel (new API).  Algo is: ${oldModel.algo}")
     val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc")
-    new DecisionTreeClassificationModel(uid, rootNode, -1)
+    // Can't infer number of features from old model, so default to -1
+    new DecisionTreeClassificationModel(uid, rootNode, numFeatures, -1)
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index ad8683648b..74aef94bf7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -33,7 +33,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
 import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{Row, DataFrame}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.DoubleType
@@ -138,10 +138,11 @@ final class GBTClassifier(override val uid: String)
     require(numClasses == 2,
       s"GBTClassifier only supports binary classification but was given numClasses = $numClasses")
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+    val numFeatures = oldDataset.first().features.size
     val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
     val oldGBT = new OldGBT(boostingStrategy)
     val oldModel = oldGBT.run(oldDataset)
-    GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures)
+    GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures, numFeatures)
   }
 
   override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra)
@@ -164,10 +165,11 @@ object GBTClassifier {
  * @param _treeWeights  Weights for the decision trees in the ensemble.
  */
 @Experimental
-final class GBTClassificationModel(
+final class GBTClassificationModel private[ml](
     override val uid: String,
     private val _trees: Array[DecisionTreeRegressionModel],
-    private val _treeWeights: Array[Double])
+    private val _treeWeights: Array[Double],
+    override val numFeatures: Int)
   extends PredictionModel[Vector, GBTClassificationModel]
   with TreeEnsembleModel with Serializable {
@@ -175,6 +177,14 @@ final class GBTClassificationModel(
   require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
     s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
 
+  /**
+   * Construct a GBTClassificationModel
+   * @param _trees  Decision trees in the ensemble.
+   * @param _treeWeights  Weights for the decision trees in the ensemble.
+   */
+  def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) =
+    this(uid, _trees, _treeWeights, -1)
+
   override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
 
   override def treeWeights: Array[Double] = _treeWeights
@@ -196,7 +206,8 @@ final class GBTClassificationModel(
   }
 
   override def copy(extra: ParamMap): GBTClassificationModel = {
-    copyValues(new GBTClassificationModel(uid, _trees, _treeWeights), extra).setParent(parent)
+    copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures),
+      extra).setParent(parent)
   }
 
   override def toString: String = {
@@ -215,7 +226,8 @@ private[ml] object GBTClassificationModel {
   def fromOld(
       oldModel: OldGBTModel,
       parent: GBTClassifier,
-      categoricalFeatures: Map[Int, Int]): GBTClassificationModel = {
+      categoricalFeatures: Map[Int, Int],
+      numFeatures: Int = -1): GBTClassificationModel = {
     require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" +
       s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).")
     val newTrees = oldModel.trees.map { tree =>
@@ -223,6 +235,6 @@
       DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
     }
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc")
-    new GBTClassificationModel(parent.uid, newTrees, oldModel.treeWeights)
+    new GBTClassificationModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures)
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index bd96e8d000..c17a7b0c36 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -426,6 +426,8 @@ class LogisticRegressionModel private[ml] (
     1.0 / (1.0 + math.exp(-m))
   }
 
+  override val numFeatures: Int = weights.size
+
   override val numClasses: Int = 2
 
   private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index 5f60dea91f..cd7462596d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -181,6 +181,8 @@ class MultilayerPerceptronClassificationModel private[ml] (
   extends PredictionModel[Vector, MultilayerPerceptronClassificationModel]
   with Serializable {
 
+  override val numFeatures: Int = layers.head
+
   private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights)
 
   /**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 082ea1ffad..a14dcecbaf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -137,6 +137,8 @@ class NaiveBayesModel private[ml] (
       throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
   }
 
+  override val numFeatures: Int = theta.numCols
+
   override val numClasses: Int = pi.size
 
   private def multinomialCalculation(features: Vector) = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index a6ebee1bb1..bae329692a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -119,13 +119,12 @@ object RandomForestClassifier {
  *                features.
  * @param _trees  Decision trees in the ensemble.
  *                Warning: These have null parents.
- * @param numFeatures  Number of features used by this model
  */
 @Experimental
 final class RandomForestClassificationModel private[ml] (
     override val uid: String,
     private val _trees: Array[DecisionTreeClassificationModel],
-    val numFeatures: Int,
+    override val numFeatures: Int,
     override val numClasses: Int)
   extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
   with TreeEnsembleModel with Serializable {
@@ -226,7 +225,8 @@ private[ml] object RandomForestClassificationModel {
       oldModel: OldRandomForestModel,
       parent: RandomForestClassifier,
       categoricalFeatures: Map[Int, Int],
-      numClasses: Int): RandomForestClassificationModel = {
+      numClasses: Int,
+      numFeatures: Int = -1): RandomForestClassificationModel = {
     require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" +
       s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).")
     val newTrees = oldModel.trees.map { tree =>
@@ -234,6 +234,6 @@
       DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
     }
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc")
-    new RandomForestClassificationModel(uid, newTrees, -1, numClasses)
+    new RandomForestClassificationModel(uid, newTrees, numFeatures, numClasses)
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index d9a244bea2..88b79a4eb8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -96,7 +96,8 @@ object DecisionTreeRegressor {
 @Experimental
 final class DecisionTreeRegressionModel private[ml] (
     override val uid: String,
-    override val rootNode: Node)
+    override val rootNode: Node,
+    override val numFeatures: Int)
   extends PredictionModel[Vector, DecisionTreeRegressionModel]
   with DecisionTreeModel with Serializable {
@@ -107,14 +108,15 @@ final class DecisionTreeRegressionModel private[ml] (
    * Construct a decision tree regression model.
    * @param rootNode  Root node of tree, with other nodes attached.
    */
-  private[ml] def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode)
+  private[ml] def this(rootNode: Node, numFeatures: Int) =
+    this(Identifiable.randomUID("dtr"), rootNode, numFeatures)
 
   override protected def predict(features: Vector): Double = {
     rootNode.predictImpl(features).prediction
   }
 
   override def copy(extra: ParamMap): DecisionTreeRegressionModel = {
-    copyValues(new DecisionTreeRegressionModel(uid, rootNode), extra).setParent(parent)
+    copyValues(new DecisionTreeRegressionModel(uid, rootNode, numFeatures), extra).setParent(parent)
   }
 
   override def toString: String = {
@@ -133,12 +135,13 @@ private[ml] object DecisionTreeRegressionModel {
   def fromOld(
       oldModel: OldDecisionTreeModel,
       parent: DecisionTreeRegressor,
-      categoricalFeatures: Map[Int, Int]): DecisionTreeRegressionModel = {
+      categoricalFeatures: Map[Int, Int],
+      numFeatures: Int = -1): DecisionTreeRegressionModel = {
     require(oldModel.algo == OldAlgo.Regression,
       s"Cannot convert non-regression DecisionTreeModel (old API) to" +
         s" DecisionTreeRegressionModel (new API).  Algo is: ${oldModel.algo}")
     val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr")
-    new DecisionTreeRegressionModel(uid, rootNode)
+    new DecisionTreeRegressionModel(uid, rootNode, numFeatures)
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index d841ecb9e5..65b5b3e072 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -128,10 +128,11 @@ final class GBTRegressor(override val uid: String)
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+    val numFeatures = oldDataset.first().features.size
     val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
     val oldGBT = new OldGBT(boostingStrategy)
     val oldModel = oldGBT.run(oldDataset)
-    GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures)
+    GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures, numFeatures)
   }
 
   override def copy(extra: ParamMap): GBTRegressor = defaultCopy(extra)
@@ -154,10 +155,11 @@ object GBTRegressor {
  * @param _treeWeights  Weights for the decision trees in the ensemble.
  */
 @Experimental
-final class GBTRegressionModel(
+final class GBTRegressionModel private[ml](
     override val uid: String,
     private val _trees: Array[DecisionTreeRegressionModel],
-    private val _treeWeights: Array[Double])
+    private val _treeWeights: Array[Double],
+    override val numFeatures: Int)
   extends PredictionModel[Vector, GBTRegressionModel]
   with TreeEnsembleModel with Serializable {
@@ -165,6 +167,14 @@ final class GBTRegressionModel(
   require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" +
     s" non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
 
+  /**
+   * Construct a GBTRegressionModel
+   * @param _trees  Decision trees in the ensemble.
+   * @param _treeWeights  Weights for the decision trees in the ensemble.
+   */
+  def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) =
+    this(uid, _trees, _treeWeights, -1)
+
   override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
 
   override def treeWeights: Array[Double] = _treeWeights
@@ -185,7 +195,8 @@ final class GBTRegressionModel(
   }
 
   override def copy(extra: ParamMap): GBTRegressionModel = {
-    copyValues(new GBTRegressionModel(uid, _trees, _treeWeights), extra).setParent(parent)
+    copyValues(new GBTRegressionModel(uid, _trees, _treeWeights, numFeatures),
+      extra).setParent(parent)
   }
 
   override def toString: String = {
@@ -204,7 +215,8 @@ private[ml] object GBTRegressionModel {
   def fromOld(
       oldModel: OldGBTModel,
       parent: GBTRegressor,
-      categoricalFeatures: Map[Int, Int]): GBTRegressionModel = {
+      categoricalFeatures: Map[Int, Int],
+      numFeatures: Int = -1): GBTRegressionModel = {
     require(oldModel.algo == OldAlgo.Regression, "Cannot convert GradientBoostedTreesModel" +
       s" with algo=${oldModel.algo} (old API) to GBTRegressionModel (new API).")
     val newTrees = oldModel.trees.map { tree =>
@@ -212,6 +224,6 @@
       DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
     }
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr")
-    new GBTRegressionModel(parent.uid, newTrees, oldModel.treeWeights)
+    new GBTRegressionModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures)
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 78a67c5fda..a77e702141 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -293,6 +293,8 @@ class LinearRegressionModel private[ml] (
 
   private var trainingSummary: Option[LinearRegressionTrainingSummary] = None
 
+  override val numFeatures: Int = weights.size
+
   /**
    * Gets summary (e.g. residuals, mse, r-squared ) of model on training set. An exception is
    * thrown if `trainingSummary == None`.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index ddb7214416..64fc17247c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -115,7 +115,7 @@ object RandomForestRegressor {
 final class RandomForestRegressionModel private[ml] (
     override val uid: String,
     private val _trees: Array[DecisionTreeRegressionModel],
-    val numFeatures: Int)
+    override val numFeatures: Int)
   extends PredictionModel[Vector, RandomForestRegressionModel]
   with TreeEnsembleModel with Serializable {
@@ -187,13 +187,14 @@ private[ml] object RandomForestRegressionModel {
   def fromOld(
       oldModel: OldRandomForestModel,
       parent: RandomForestRegressor,
-      categoricalFeatures: Map[Int, Int]): RandomForestRegressionModel = {
+      categoricalFeatures: Map[Int, Int],
+      numFeatures: Int = -1): RandomForestRegressionModel = {
     require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" +
       s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).")
     val newTrees = oldModel.trees.map { tree =>
       // parent for each tree is null since there is no good way to set this.
       DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
     }
-    new RandomForestRegressionModel(parent.uid, newTrees, -1)
+    new RandomForestRegressionModel(parent.uid, newTrees, numFeatures)
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 4ac51a4754..c494556085 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -179,22 +179,28 @@ private[ml] object RandomForest extends Logging {
       }
     }
 
+    val numFeatures = metadata.numFeatures
+
     parentUID match {
       case Some(uid) =>
         if (strategy.algo == OldAlgo.Classification) {
           topNodes.map { rootNode =>
-            new DecisionTreeClassificationModel(uid, rootNode.toNode, strategy.getNumClasses)
+            new DecisionTreeClassificationModel(uid, rootNode.toNode, numFeatures,
+              strategy.getNumClasses)
           }
         } else {
-          topNodes.map(rootNode => new DecisionTreeRegressionModel(uid, rootNode.toNode))
+          topNodes.map { rootNode =>
+            new DecisionTreeRegressionModel(uid, rootNode.toNode, numFeatures)
+          }
         }
       case None =>
         if (strategy.algo == OldAlgo.Classification) {
           topNodes.map { rootNode =>
-            new DecisionTreeClassificationModel(rootNode.toNode, strategy.getNumClasses)
+            new DecisionTreeClassificationModel(rootNode.toNode, numFeatures,
+              strategy.getNumClasses)
           }
         } else {
-          topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode))
+          topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode, numFeatures))
         }
     }
   }
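The same backward-compatibility pattern repeats across the `fromOld` converters above: the new `numFeatures` parameter is given a default of -1, so pre-existing call sites compile unchanged while updated callers pass the real count. A standalone sketch of the idea (class and method names hypothetical, not from the patch):

// Hypothetical illustration of the defaulted-parameter pattern used throughout this commit.
abstract class Model {
  /** Number of features the model was trained on; -1 when unknown. */
  def numFeatures: Int = -1
}

class TreeModel(override val numFeatures: Int) extends Model

object TreeModel {
  // The new parameter defaults to -1, so older fromOld() callers still compile.
  def fromOld(numFeatures: Int = -1): TreeModel = new TreeModel(numFeatures)
}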