diff options
author | sethah <seth.hendrickson16@gmail.com> | 2015-09-23 15:00:52 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-09-23 15:00:52 -0700 |
commit | 098be27ad53c485ee2fc7f5871c47f899020e87b (patch) | |
tree | 1e6fe63cc0bb8bd6088b4117bc1951fdd6c42507 /mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala | |
parent | a18208047f06a4244703c17023bb20cbe1f59d73 (diff) | |
download | spark-098be27ad53c485ee2fc7f5871c47f899020e87b.tar.gz spark-098be27ad53c485ee2fc7f5871c47f899020e87b.tar.bz2 spark-098be27ad53c485ee2fc7f5871c47f899020e87b.zip |
[SPARK-9715] [ML] Store numFeatures in all ML PredictionModel types
All prediction models should store `numFeatures` indicating the number of features the model was trained on. Default value of -1 added for backwards compatibility.
Author: sethah <seth.hendrickson16@gmail.com>
Closes #8675 from sethah/SPARK-9715.
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index b8eb49f9bd..a6f6d463bf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -107,6 +107,7 @@ object DecisionTreeClassifier { final class DecisionTreeClassificationModel private[ml] ( override val uid: String, override val rootNode: Node, + override val numFeatures: Int, override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel] with DecisionTreeModel with Serializable { @@ -118,8 +119,8 @@ final class DecisionTreeClassificationModel private[ml] ( * Construct a decision tree classification model. * @param rootNode Root node of tree, with other nodes attached. */ - private[ml] def this(rootNode: Node, numClasses: Int) = - this(Identifiable.randomUID("dtc"), rootNode, numClasses) + private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) = + this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses) override protected def predict(features: Vector): Double = { rootNode.predictImpl(features).prediction @@ -141,7 +142,7 @@ final class DecisionTreeClassificationModel private[ml] ( } override def copy(extra: ParamMap): DecisionTreeClassificationModel = { - copyValues(new DecisionTreeClassificationModel(uid, rootNode, numClasses), extra) + copyValues(new DecisionTreeClassificationModel(uid, rootNode, numFeatures, numClasses), extra) .setParent(parent) } @@ -161,12 +162,14 @@ private[ml] object DecisionTreeClassificationModel { def fromOld( oldModel: OldDecisionTreeModel, parent: DecisionTreeClassifier, - categoricalFeatures: Map[Int, Int]): DecisionTreeClassificationModel = { + categoricalFeatures: Map[Int, Int], + numFeatures: Int = -1): DecisionTreeClassificationModel = { require(oldModel.algo == OldAlgo.Classification, s"Cannot convert non-classification DecisionTreeModel (old API) to" + s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}") val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc") - new DecisionTreeClassificationModel(uid, rootNode, -1) + // Can't infer number of features from old model, so default to -1 + new DecisionTreeClassificationModel(uid, rootNode, numFeatures, -1) } } |