aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
diff options
context:
space:
mode:
author: sethah <seth.hendrickson16@gmail.com> 2015-09-23 15:00:52 -0700
committer: Joseph K. Bradley <joseph@databricks.com> 2015-09-23 15:00:52 -0700
commit: 098be27ad53c485ee2fc7f5871c47f899020e87b (patch)
tree: 1e6fe63cc0bb8bd6088b4117bc1951fdd6c42507 /mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
parent: a18208047f06a4244703c17023bb20cbe1f59d73 (diff)
download: spark-098be27ad53c485ee2fc7f5871c47f899020e87b.tar.gz
spark-098be27ad53c485ee2fc7f5871c47f899020e87b.tar.bz2
spark-098be27ad53c485ee2fc7f5871c47f899020e87b.zip
[SPARK-9715] [ML] Store numFeatures in all ML PredictionModel types
All prediction models should store `numFeatures` indicating the number of features the model was trained on. Default value of -1 added for backwards compatibility.

Author: sethah <seth.hendrickson16@gmail.com>

Closes #8675 from sethah/SPARK-9715.
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala 13
1 files changed, 8 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index b8eb49f9bd..a6f6d463bf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -107,6 +107,7 @@ object DecisionTreeClassifier {
final class DecisionTreeClassificationModel private[ml] (
override val uid: String,
override val rootNode: Node,
+ override val numFeatures: Int,
override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
with DecisionTreeModel with Serializable {
@@ -118,8 +119,8 @@ final class DecisionTreeClassificationModel private[ml] (
* Construct a decision tree classification model.
* @param rootNode Root node of tree, with other nodes attached.
*/
- private[ml] def this(rootNode: Node, numClasses: Int) =
- this(Identifiable.randomUID("dtc"), rootNode, numClasses)
+ private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
+ this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses)
override protected def predict(features: Vector): Double = {
rootNode.predictImpl(features).prediction
@@ -141,7 +142,7 @@ final class DecisionTreeClassificationModel private[ml] (
}
override def copy(extra: ParamMap): DecisionTreeClassificationModel = {
- copyValues(new DecisionTreeClassificationModel(uid, rootNode, numClasses), extra)
+ copyValues(new DecisionTreeClassificationModel(uid, rootNode, numFeatures, numClasses), extra)
.setParent(parent)
}
@@ -161,12 +162,14 @@ private[ml] object DecisionTreeClassificationModel {
def fromOld(
oldModel: OldDecisionTreeModel,
parent: DecisionTreeClassifier,
- categoricalFeatures: Map[Int, Int]): DecisionTreeClassificationModel = {
+ categoricalFeatures: Map[Int, Int],
+ numFeatures: Int = -1): DecisionTreeClassificationModel = {
require(oldModel.algo == OldAlgo.Classification,
s"Cannot convert non-classification DecisionTreeModel (old API) to" +
s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}")
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc")
- new DecisionTreeClassificationModel(uid, rootNode, -1)
+ // Can't infer number of features from old model, so default to -1
+ new DecisionTreeClassificationModel(uid, rootNode, numFeatures, -1)
}
}