diff options
Diffstat (limited to 'mllib')
32 files changed, 2104 insertions, 263 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala index aa27a668f1..d7dee8fed2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala @@ -117,12 +117,12 @@ class AttributeGroup private ( case numeric: NumericAttribute => // Skip default numeric attributes. if (numeric.withoutIndex != NumericAttribute.defaultAttr) { - numericMetadata += numeric.toMetadata(withType = false) + numericMetadata += numeric.toMetadataImpl(withType = false) } case nominal: NominalAttribute => - nominalMetadata += nominal.toMetadata(withType = false) + nominalMetadata += nominal.toMetadataImpl(withType = false) case binary: BinaryAttribute => - binaryMetadata += binary.toMetadata(withType = false) + binaryMetadata += binary.toMetadataImpl(withType = false) } val attrBldr = new MetadataBuilder if (numericMetadata.nonEmpty) { @@ -151,7 +151,7 @@ class AttributeGroup private ( } /** Converts to ML metadata */ - def toMetadata: Metadata = toMetadata(Metadata.empty) + def toMetadata(): Metadata = toMetadata(Metadata.empty) /** Converts to a StructField with some existing metadata. */ def toStructField(existingMetadata: Metadata): StructField = { @@ -159,7 +159,7 @@ class AttributeGroup private ( } /** Converts to a StructField. */ - def toStructField: StructField = toStructField(Metadata.empty) + def toStructField(): StructField = toStructField(Metadata.empty) override def equals(other: Any): Boolean = { other match { diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala index 00b7566aab..5717d6ec2e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala @@ -68,21 +68,32 @@ sealed abstract class Attribute extends Serializable { * Converts this attribute to [[Metadata]]. * @param withType whether to include the type info */ - private[attribute] def toMetadata(withType: Boolean): Metadata + private[attribute] def toMetadataImpl(withType: Boolean): Metadata /** * Converts this attribute to [[Metadata]]. For numeric attributes, the type info is excluded to * save space, because numeric type is the default attribute type. For nominal and binary * attributes, the type info is included. */ - private[attribute] def toMetadata(): Metadata = { + private[attribute] def toMetadataImpl(): Metadata = { if (attrType == AttributeType.Numeric) { - toMetadata(withType = false) + toMetadataImpl(withType = false) } else { - toMetadata(withType = true) + toMetadataImpl(withType = true) } } + /** Converts to ML metadata with some existing metadata. */ + def toMetadata(existingMetadata: Metadata): Metadata = { + new MetadataBuilder() + .withMetadata(existingMetadata) + .putMetadata(AttributeKeys.ML_ATTR, toMetadataImpl()) + .build() + } + + /** Converts to ML metadata */ + def toMetadata(): Metadata = toMetadata(Metadata.empty) + /** * Converts to a [[StructField]] with some existing metadata. * @param existingMetadata existing metadata to carry over @@ -90,7 +101,7 @@ sealed abstract class Attribute extends Serializable { def toStructField(existingMetadata: Metadata): StructField = { val newMetadata = new MetadataBuilder() .withMetadata(existingMetadata) - .putMetadata(AttributeKeys.ML_ATTR, withoutName.withoutIndex.toMetadata()) + .putMetadata(AttributeKeys.ML_ATTR, withoutName.withoutIndex.toMetadataImpl()) .build() StructField(name.get, DoubleType, nullable = false, newMetadata) } @@ -98,7 +109,7 @@ sealed abstract class Attribute extends Serializable { /** Converts to a [[StructField]]. */ def toStructField(): StructField = toStructField(Metadata.empty) - override def toString: String = toMetadata(withType = true).toString + override def toString: String = toMetadataImpl(withType = true).toString } /** Trait for ML attribute factories. */ @@ -210,7 +221,7 @@ class NumericAttribute private[ml] ( override def isNominal: Boolean = false /** Convert this attribute to metadata. */ - private[attribute] override def toMetadata(withType: Boolean): Metadata = { + override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = { import org.apache.spark.ml.attribute.AttributeKeys._ val bldr = new MetadataBuilder() if (withType) bldr.putString(TYPE, attrType.name) @@ -353,6 +364,20 @@ class NominalAttribute private[ml] ( /** Copy without the `numValues`. */ def withoutNumValues: NominalAttribute = copy(numValues = None) + /** + * Get the number of values, either from `numValues` or from `values`. + * Return None if unknown. + */ + def getNumValues: Option[Int] = { + if (numValues.nonEmpty) { + numValues + } else if (values.nonEmpty) { + Some(values.get.length) + } else { + None + } + } + /** Creates a copy of this attribute with optional changes. */ private def copy( name: Option[String] = name, @@ -363,7 +388,7 @@ class NominalAttribute private[ml] ( new NominalAttribute(name, index, isOrdinal, numValues, values) } - private[attribute] override def toMetadata(withType: Boolean): Metadata = { + override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = { import org.apache.spark.ml.attribute.AttributeKeys._ val bldr = new MetadataBuilder() if (withType) bldr.putString(TYPE, attrType.name) @@ -465,7 +490,7 @@ class BinaryAttribute private[ml] ( new BinaryAttribute(name, index, values) } - private[attribute] override def toMetadata(withType: Boolean): Metadata = { + override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = { import org.apache.spark.ml.attribute.AttributeKeys._ val bldr = new MetadataBuilder if (withType) bldr.putString(TYPE, attrType.name) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala new file mode 100644 index 0000000000..3855e396b5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.impl.estimator.{Predictor, PredictionModel} +import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.param.{Params, ParamMap} +import org.apache.spark.ml.tree.{DecisionTreeModel, Node} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm + * for classification. + * It supports both binary and multiclass labels, as well as both continuous and categorical + * features. + */ +@AlphaComponent +final class DecisionTreeClassifier + extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] + with DecisionTreeParams + with TreeClassifierParams { + + // Override parameter setters from parent trait for Java API compatibility. + + override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + + override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + + override def setMinInstancesPerNode(value: Int): this.type = + super.setMinInstancesPerNode(value) + + override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + + override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + + override def setCacheNodeIds(value: Boolean): this.type = + super.setCacheNodeIds(value) + + override def setCheckpointInterval(value: Int): this.type = + super.setCheckpointInterval(value) + + override def setImpurity(value: String): this.type = super.setImpurity(value) + + override protected def train( + dataset: DataFrame, + paramMap: ParamMap): DecisionTreeClassificationModel = { + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol))) + val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match { + case Some(n: Int) => n + case None => throw new IllegalArgumentException("DecisionTreeClassifier was given input" + + s" with invalid label column, without the number of classes specified.") + // TODO: Automatically index labels. + } + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap) + val strategy = getOldStrategy(categoricalFeatures, numClasses) + val oldModel = OldDecisionTree.train(oldDataset, strategy) + DecisionTreeClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures) + } + + /** (private[ml]) Create a Strategy instance to use with the old API. */ + override private[ml] def getOldStrategy( + categoricalFeatures: Map[Int, Int], + numClasses: Int): OldStrategy = { + val strategy = super.getOldStrategy(categoricalFeatures, numClasses) + strategy.algo = OldAlgo.Classification + strategy.setImpurity(getOldImpurity) + strategy + } +} + +object DecisionTreeClassifier { + /** Accessor for supported impurities */ + final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities +} + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for classification. + * It supports both binary and multiclass labels, as well as both continuous and categorical + * features. + */ +@AlphaComponent +final class DecisionTreeClassificationModel private[ml] ( + override val parent: DecisionTreeClassifier, + override val fittingParamMap: ParamMap, + override val rootNode: Node) + extends PredictionModel[Vector, DecisionTreeClassificationModel] + with DecisionTreeModel with Serializable { + + require(rootNode != null, + "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.") + + override protected def predict(features: Vector): Double = { + rootNode.predict(features) + } + + override protected def copy(): DecisionTreeClassificationModel = { + val m = new DecisionTreeClassificationModel(parent, fittingParamMap, rootNode) + Params.inheritValues(this.extractParamMap(), this, m) + m + } + + override def toString: String = { + s"DecisionTreeClassificationModel of depth $depth with $numNodes nodes" + } + + /** (private[ml]) Convert to a model in the old API */ + private[ml] def toOld: OldDecisionTreeModel = { + new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Classification) + } +} + +private[ml] object DecisionTreeClassificationModel { + + /** (private[ml]) Convert a model from the old API */ + def fromOld( + oldModel: OldDecisionTreeModel, + parent: DecisionTreeClassifier, + fittingParamMap: ParamMap, + categoricalFeatures: Map[Int, Int]): DecisionTreeClassificationModel = { + require(oldModel.algo == OldAlgo.Classification, + s"Cannot convert non-classification DecisionTreeModel (old API) to" + + s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}") + val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) + new DecisionTreeClassificationModel(parent, fittingParamMap, rootNode) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 4d960df357..23956c512c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -118,7 +118,7 @@ class StringIndexerModel private[ml] ( } val outputColName = map(outputCol) val metadata = NominalAttribute.defaultAttr - .withName(outputColName).withValues(labels).toStructField().metadata + .withName(outputColName).withValues(labels).toMetadata() dataset.select(col("*"), indexer(dataset(map(inputCol))).as(outputColName, metadata)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala new file mode 100644 index 0000000000..6f4509f03d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala @@ -0,0 +1,300 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.impl.tree + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.impl.estimator.PredictorParams +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.impurity.{Gini => OldGini, Entropy => OldEntropy, + Impurity => OldImpurity, Variance => OldVariance} + + +/** + * :: DeveloperApi :: + * Parameters for Decision Tree-based algorithms. + * + * Note: Marked as private and DeveloperApi since this may be made public in the future. + */ +@DeveloperApi +private[ml] trait DecisionTreeParams extends PredictorParams { + + /** + * Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * (default = 5) + * @group param + */ + final val maxDepth: IntParam = + new IntParam(this, "maxDepth", "Maximum depth of the tree." + + " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.") + + /** + * Maximum number of bins used for discretizing continuous features and for choosing how to split + * on features at each node. More bins give higher granularity. + * Must be >= 2 and >= number of categories in any categorical feature. + * (default = 32) + * @group param + */ + final val maxBins: IntParam = new IntParam(this, "maxBins", "Max number of bins for" + + " discretizing continuous features. Must be >=2 and >= number of categories for any" + + " categorical feature.") + + /** + * Minimum number of instances each child must have after split. + * If a split causes the left or right child to have fewer than minInstancesPerNode, + * the split will be discarded as invalid. + * Should be >= 1. + * (default = 1) + * @group param + */ + final val minInstancesPerNode: IntParam = new IntParam(this, "minInstancesPerNode", "Minimum" + + " number of instances each child must have after split. If a split causes the left or right" + + " child to have fewer than minInstancesPerNode, the split will be discarded as invalid." + + " Should be >= 1.") + + /** + * Minimum information gain for a split to be considered at a tree node. + * (default = 0.0) + * @group param + */ + final val minInfoGain: DoubleParam = new DoubleParam(this, "minInfoGain", + "Minimum information gain for a split to be considered at a tree node.") + + /** + * Maximum memory in MB allocated to histogram aggregation. + * (default = 256 MB) + * @group expertParam + */ + final val maxMemoryInMB: IntParam = new IntParam(this, "maxMemoryInMB", + "Maximum memory in MB allocated to histogram aggregation.") + + /** + * If false, the algorithm will pass trees to executors to match instances with nodes. + * If true, the algorithm will cache node IDs for each instance. + * Caching can speed up training of deeper trees. + * (default = false) + * @group expertParam + */ + final val cacheNodeIds: BooleanParam = new BooleanParam(this, "cacheNodeIds", "If false, the" + + " algorithm will pass trees to executors to match instances with nodes. If true, the" + + " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" + + " trees.") + + /** + * Specifies how often to checkpoint the cached node IDs. + * E.g. 10 means that the cache will get checkpointed every 10 iterations. + * This is only used if cacheNodeIds is true and if the checkpoint directory is set in + * [[org.apache.spark.SparkContext]]. + * Must be >= 1. + * (default = 10) + * @group expertParam + */ + final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "Specifies" + + " how often to checkpoint the cached node IDs. E.g. 10 means that the cache will get" + + " checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the" + + " checkpoint directory is set in the SparkContext. Must be >= 1.") + + setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, + maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10) + + /** @group setParam */ + def setMaxDepth(value: Int): this.type = { + require(value >= 0, s"maxDepth parameter must be >= 0. Given bad value: $value") + set(maxDepth, value) + this.asInstanceOf[this.type] + } + + /** @group getParam */ + def getMaxDepth: Int = getOrDefault(maxDepth) + + /** @group setParam */ + def setMaxBins(value: Int): this.type = { + require(value >= 2, s"maxBins parameter must be >= 2. Given bad value: $value") + set(maxBins, value) + this + } + + /** @group getParam */ + def getMaxBins: Int = getOrDefault(maxBins) + + /** @group setParam */ + def setMinInstancesPerNode(value: Int): this.type = { + require(value >= 1, s"minInstancesPerNode parameter must be >= 1. Given bad value: $value") + set(minInstancesPerNode, value) + this + } + + /** @group getParam */ + def getMinInstancesPerNode: Int = getOrDefault(minInstancesPerNode) + + /** @group setParam */ + def setMinInfoGain(value: Double): this.type = { + set(minInfoGain, value) + this + } + + /** @group getParam */ + def getMinInfoGain: Double = getOrDefault(minInfoGain) + + /** @group expertSetParam */ + def setMaxMemoryInMB(value: Int): this.type = { + require(value > 0, s"maxMemoryInMB parameter must be > 0. Given bad value: $value") + set(maxMemoryInMB, value) + this + } + + /** @group expertGetParam */ + def getMaxMemoryInMB: Int = getOrDefault(maxMemoryInMB) + + /** @group expertSetParam */ + def setCacheNodeIds(value: Boolean): this.type = { + set(cacheNodeIds, value) + this + } + + /** @group expertGetParam */ + def getCacheNodeIds: Boolean = getOrDefault(cacheNodeIds) + + /** @group expertSetParam */ + def setCheckpointInterval(value: Int): this.type = { + require(value >= 1, s"checkpointInterval parameter must be >= 1. Given bad value: $value") + set(checkpointInterval, value) + this + } + + /** @group expertGetParam */ + def getCheckpointInterval: Int = getOrDefault(checkpointInterval) + + /** + * Create a Strategy instance to use with the old API. + * NOTE: The caller should set impurity and subsamplingRate (which is set to 1.0, + * the default for single trees). + */ + private[ml] def getOldStrategy( + categoricalFeatures: Map[Int, Int], + numClasses: Int): OldStrategy = { + val strategy = OldStrategy.defaultStategy(OldAlgo.Classification) + strategy.checkpointInterval = getCheckpointInterval + strategy.maxBins = getMaxBins + strategy.maxDepth = getMaxDepth + strategy.maxMemoryInMB = getMaxMemoryInMB + strategy.minInfoGain = getMinInfoGain + strategy.minInstancesPerNode = getMinInstancesPerNode + strategy.useNodeIdCache = getCacheNodeIds + strategy.numClasses = numClasses + strategy.categoricalFeaturesInfo = categoricalFeatures + strategy.subsamplingRate = 1.0 // default for individual trees + strategy + } +} + +/** + * (private trait) Parameters for Decision Tree-based classification algorithms. + */ +private[ml] trait TreeClassifierParams extends Params { + + /** + * Criterion used for information gain calculation (case-insensitive). + * Supported: "entropy" and "gini". + * (default = gini) + * @group param + */ + val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + + " information gain calculation (case-insensitive). Supported options:" + + s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}") + + setDefault(impurity -> "gini") + + /** @group setParam */ + def setImpurity(value: String): this.type = { + val impurityStr = value.toLowerCase + require(TreeClassifierParams.supportedImpurities.contains(impurityStr), + s"Tree-based classifier was given unrecognized impurity: $value." + + s" Supported options: ${TreeClassifierParams.supportedImpurities.mkString(", ")}") + set(impurity, impurityStr) + this + } + + /** @group getParam */ + def getImpurity: String = getOrDefault(impurity) + + /** Convert new impurity to old impurity. */ + private[ml] def getOldImpurity: OldImpurity = { + getImpurity match { + case "entropy" => OldEntropy + case "gini" => OldGini + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException( + s"TreeClassifierParams was given unrecognized impurity: $impurity.") + } + } +} + +private[ml] object TreeClassifierParams { + // These options should be lowercase. + val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase) +} + +/** + * (private trait) Parameters for Decision Tree-based regression algorithms. + */ +private[ml] trait TreeRegressorParams extends Params { + + /** + * Criterion used for information gain calculation (case-insensitive). + * Supported: "variance". + * (default = variance) + * @group param + */ + val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + + " information gain calculation (case-insensitive). Supported options:" + + s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}") + + setDefault(impurity -> "variance") + + /** @group setParam */ + def setImpurity(value: String): this.type = { + val impurityStr = value.toLowerCase + require(TreeRegressorParams.supportedImpurities.contains(impurityStr), + s"Tree-based regressor was given unrecognized impurity: $value." + + s" Supported options: ${TreeRegressorParams.supportedImpurities.mkString(", ")}") + set(impurity, impurityStr) + this + } + + /** @group getParam */ + def getImpurity: String = getOrDefault(impurity) + + /** Convert new impurity to old impurity. */ + protected def getOldImpurity: OldImpurity = { + getImpurity match { + case "variance" => OldVariance + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException( + s"TreeRegressorParams was given unrecognized impurity: $impurity") + } + } +} + +private[ml] object TreeRegressorParams { + // These options should be lowercase. + val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala index b45bd1499b..ac75e9de1a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/package.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala @@ -32,6 +32,18 @@ package org.apache.spark * @groupname getParam Parameter getters * @groupprio getParam 6 * + * @groupname expertParam (expert-only) Parameters + * @groupdesc expertParam A list of advanced, expert-only (hyper-)parameter keys this algorithm can + * take. Users can set and get the parameter values through setters and getters, + * respectively. + * @groupprio expertParam 7 + * + * @groupname expertSetParam (expert-only) Parameter setters + * @groupprio expertSetParam 8 + * + * @groupname expertGetParam (expert-only) Parameter getters + * @groupprio expertGetParam 9 + * * @groupname Ungrouped Members * @groupprio Ungrouped 0 */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 849c60433c..ddc5907e7f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -296,8 +296,9 @@ private[spark] object Params { paramMap: ParamMap, parent: E, child: M): Unit = { + val childParams = child.params.map(_.name).toSet parent.params.foreach { param => - if (paramMap.contains(param)) { + if (paramMap.contains(param) && childParams.contains(param.name)) { child.set(child.getParam(param.name), paramMap(param)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala new file mode 100644 index 0000000000..49a8b77acf --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} +import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.param.{Params, ParamMap} +import org.apache.spark.ml.tree.{DecisionTreeModel, Node} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm + * for regression. + * It supports both continuous and categorical features. + */ +@AlphaComponent +final class DecisionTreeRegressor + extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] + with DecisionTreeParams + with TreeRegressorParams { + + // Override parameter setters from parent trait for Java API compatibility. + + override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + + override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + + override def setMinInstancesPerNode(value: Int): this.type = + super.setMinInstancesPerNode(value) + + override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + + override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + + override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + + override def setCheckpointInterval(value: Int): this.type = + super.setCheckpointInterval(value) + + override def setImpurity(value: String): this.type = super.setImpurity(value) + + override protected def train( + dataset: DataFrame, + paramMap: ParamMap): DecisionTreeRegressionModel = { + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol))) + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap) + val strategy = getOldStrategy(categoricalFeatures) + val oldModel = OldDecisionTree.train(oldDataset, strategy) + DecisionTreeRegressionModel.fromOld(oldModel, this, paramMap, categoricalFeatures) + } + + /** (private[ml]) Create a Strategy instance to use with the old API. */ + private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = { + val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 0) + strategy.algo = OldAlgo.Regression + strategy.setImpurity(getOldImpurity) + strategy + } +} + +object DecisionTreeRegressor { + /** Accessor for supported impurities */ + final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities +} + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for regression. + * It supports both continuous and categorical features. + * @param rootNode Root of the decision tree + */ +@AlphaComponent +final class DecisionTreeRegressionModel private[ml] ( + override val parent: DecisionTreeRegressor, + override val fittingParamMap: ParamMap, + override val rootNode: Node) + extends PredictionModel[Vector, DecisionTreeRegressionModel] + with DecisionTreeModel with Serializable { + + require(rootNode != null, + "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.") + + override protected def predict(features: Vector): Double = { + rootNode.predict(features) + } + + override protected def copy(): DecisionTreeRegressionModel = { + val m = new DecisionTreeRegressionModel(parent, fittingParamMap, rootNode) + Params.inheritValues(this.extractParamMap(), this, m) + m + } + + override def toString: String = { + s"DecisionTreeRegressionModel of depth $depth with $numNodes nodes" + } + + /** Convert to a model in the old API */ + private[ml] def toOld: OldDecisionTreeModel = { + new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression) + } +} + +private[ml] object DecisionTreeRegressionModel { + + /** (private[ml]) Convert a model from the old API */ + def fromOld( + oldModel: OldDecisionTreeModel, + parent: DecisionTreeRegressor, + fittingParamMap: ParamMap, + categoricalFeatures: Map[Int, Int]): DecisionTreeRegressionModel = { + require(oldModel.algo == OldAlgo.Regression, + s"Cannot convert non-regression DecisionTreeModel (old API) to" + + s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}") + val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) + new DecisionTreeRegressionModel(parent, fittingParamMap, rootNode) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala new file mode 100644 index 0000000000..d6e2203d9f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree + +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats, + Node => OldNode, Predict => OldPredict} + + +/** + * Decision tree node interface. + */ +sealed abstract class Node extends Serializable { + + // TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree + // code into the new API and deprecate the old API. + + /** Prediction this node makes (or would make, if it is an internal node) */ + def prediction: Double + + /** Impurity measure at this node (for training data) */ + def impurity: Double + + /** Recursive prediction helper method */ + private[ml] def predict(features: Vector): Double = prediction + + /** + * Get the number of nodes in tree below this node, including leaf nodes. + * E.g., if this is a leaf, returns 0. If both children are leaves, returns 2. + */ + private[tree] def numDescendants: Int + + /** + * Recursive print function. + * @param indentFactor The number of spaces to add to each level of indentation. + */ + private[tree] def subtreeToString(indentFactor: Int = 0): String + + /** + * Get depth of tree from this node. + * E.g.: Depth 0 means this is a leaf node. Depth 1 means 1 internal and 2 leaf nodes. + */ + private[tree] def subtreeDepth: Int + + /** + * Create a copy of this node in the old Node format, recursively creating child nodes as needed. + * @param id Node ID using old format IDs + */ + private[ml] def toOld(id: Int): OldNode +} + +private[ml] object Node { + + /** + * Create a new Node from the old Node format, recursively creating child nodes as needed. + */ + def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node = { + if (oldNode.isLeaf) { + // TODO: Once the implementation has been moved to this API, then include sufficient + // statistics here. + new LeafNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity) + } else { + val gain = if (oldNode.stats.nonEmpty) { + oldNode.stats.get.gain + } else { + 0.0 + } + new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity, + gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures), + rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures), + split = Split.fromOld(oldNode.split.get, categoricalFeatures)) + } + } +} + +/** + * Decision tree leaf node. + * @param prediction Prediction this node makes + * @param impurity Impurity measure at this node (for training data) + */ +final class LeafNode private[ml] ( + override val prediction: Double, + override val impurity: Double) extends Node { + + override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)" + + override private[ml] def predict(features: Vector): Double = prediction + + override private[tree] def numDescendants: Int = 0 + + override private[tree] def subtreeToString(indentFactor: Int = 0): String = { + val prefix: String = " " * indentFactor + prefix + s"Predict: $prediction\n" + } + + override private[tree] def subtreeDepth: Int = 0 + + override private[ml] def toOld(id: Int): OldNode = { + // NOTE: We do NOT store 'prob' in the new API currently. + new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = true, + None, None, None, None) + } +} + +/** + * Internal Decision Tree node. + * @param prediction Prediction this node would make if it were a leaf node + * @param impurity Impurity measure at this node (for training data) + * @param gain Information gain value. + * Values < 0 indicate missing values; this quirk will be removed with future updates. + * @param leftChild Left-hand child node + * @param rightChild Right-hand child node + * @param split Information about the test used to split to the left or right child. + */ +final class InternalNode private[ml] ( + override val prediction: Double, + override val impurity: Double, + val gain: Double, + val leftChild: Node, + val rightChild: Node, + val split: Split) extends Node { + + override def toString: String = { + s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)" + } + + override private[ml] def predict(features: Vector): Double = { + if (split.shouldGoLeft(features)) { + leftChild.predict(features) + } else { + rightChild.predict(features) + } + } + + override private[tree] def numDescendants: Int = { + 2 + leftChild.numDescendants + rightChild.numDescendants + } + + override private[tree] def subtreeToString(indentFactor: Int = 0): String = { + val prefix: String = " " * indentFactor + prefix + s"If (${InternalNode.splitToString(split, left=true)})\n" + + leftChild.subtreeToString(indentFactor + 1) + + prefix + s"Else (${InternalNode.splitToString(split, left=false)})\n" + + rightChild.subtreeToString(indentFactor + 1) + } + + override private[tree] def subtreeDepth: Int = { + 1 + math.max(leftChild.subtreeDepth, rightChild.subtreeDepth) + } + + override private[ml] def toOld(id: Int): OldNode = { + assert(id.toLong * 2 < Int.MaxValue, "Decision Tree could not be converted from new to old API" + + " since the old API does not support deep trees.") + // NOTE: We do NOT store 'prob' in the new API currently. + new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = false, + Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))), + Some(rightChild.toOld(OldNode.rightChildIndex(id))), + Some(new OldInformationGainStats(gain, impurity, leftChild.impurity, rightChild.impurity, + new OldPredict(leftChild.prediction, prob = 0.0), + new OldPredict(rightChild.prediction, prob = 0.0)))) + } +} + +private object InternalNode { + + /** + * Helper method for [[Node.subtreeToString()]]. + * @param split Split to print + * @param left Indicates whether this is the part of the split going to the left, + * or that going to the right. + */ + private def splitToString(split: Split, left: Boolean): String = { + val featureStr = s"feature ${split.featureIndex}" + split match { + case contSplit: ContinuousSplit => + if (left) { + s"$featureStr <= ${contSplit.threshold}" + } else { + s"$featureStr > ${contSplit.threshold}" + } + case catSplit: CategoricalSplit => + val categoriesStr = catSplit.getLeftCategories.mkString("{", ",", "}") + if (left) { + s"$featureStr in $categoriesStr" + } else { + s"$featureStr not in $categoriesStr" + } + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala new file mode 100644 index 0000000000..cb940f6299 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree + +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType} +import org.apache.spark.mllib.tree.model.{Split => OldSplit} + + +/** + * Interface for a "Split," which specifies a test made at a decision tree node + * to choose the left or right path. + */ +sealed trait Split extends Serializable { + + /** Index of feature which this split tests */ + def featureIndex: Int + + /** Return true (split to left) or false (split to right) */ + private[ml] def shouldGoLeft(features: Vector): Boolean + + /** Convert to old Split format */ + private[tree] def toOld: OldSplit +} + +private[ml] object Split { + + def fromOld(oldSplit: OldSplit, categoricalFeatures: Map[Int, Int]): Split = { + oldSplit.featureType match { + case OldFeatureType.Categorical => + new CategoricalSplit(featureIndex = oldSplit.feature, + leftCategories = oldSplit.categories.toArray, categoricalFeatures(oldSplit.feature)) + case OldFeatureType.Continuous => + new ContinuousSplit(featureIndex = oldSplit.feature, threshold = oldSplit.threshold) + } + } +} + +/** + * Split which tests a categorical feature. + * @param featureIndex Index of the feature to test + * @param leftCategories If the feature value is in this set of categories, then the split goes + * left. Otherwise, it goes right. + * @param numCategories Number of categories for this feature. + */ +final class CategoricalSplit( + override val featureIndex: Int, + leftCategories: Array[Double], + private val numCategories: Int) + extends Split { + + require(leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" + + s" (should be in range [0, $numCategories)): ${leftCategories.mkString(",")}") + + /** + * If true, then "categories" is the set of categories for splitting to the left, and vice versa. + */ + private val isLeft: Boolean = leftCategories.length <= numCategories / 2 + + /** Set of categories determining the splitting rule, along with [[isLeft]]. */ + private val categories: Set[Double] = { + if (isLeft) { + leftCategories.toSet + } else { + setComplement(leftCategories.toSet) + } + } + + override private[ml] def shouldGoLeft(features: Vector): Boolean = { + if (isLeft) { + categories.contains(features(featureIndex)) + } else { + !categories.contains(features(featureIndex)) + } + } + + override def equals(o: Any): Boolean = { + o match { + case other: CategoricalSplit => featureIndex == other.featureIndex && + isLeft == other.isLeft && categories == other.categories + case _ => false + } + } + + override private[tree] def toOld: OldSplit = { + val oldCats = if (isLeft) { + categories + } else { + setComplement(categories) + } + OldSplit(featureIndex, threshold = 0.0, OldFeatureType.Categorical, oldCats.toList) + } + + /** Get sorted categories which split to the left */ + def getLeftCategories: Array[Double] = { + val cats = if (isLeft) categories else setComplement(categories) + cats.toArray.sorted + } + + /** Get sorted categories which split to the right */ + def getRightCategories: Array[Double] = { + val cats = if (isLeft) setComplement(categories) else categories + cats.toArray.sorted + } + + /** [0, numCategories) \ cats */ + private def setComplement(cats: Set[Double]): Set[Double] = { + Range(0, numCategories).map(_.toDouble).filter(cat => !cats.contains(cat)).toSet + } +} + +/** + * Split which tests a continuous feature. + * @param featureIndex Index of the feature to test + * @param threshold If the feature value is <= this threshold, then the split goes left. + * Otherwise, it goes right. + */ +final class ContinuousSplit(override val featureIndex: Int, val threshold: Double) extends Split { + + override private[ml] def shouldGoLeft(features: Vector): Boolean = { + features(featureIndex) <= threshold + } + + override def equals(o: Any): Boolean = { + o match { + case other: ContinuousSplit => + featureIndex == other.featureIndex && threshold == other.threshold + case _ => + false + } + } + + override private[tree] def toOld: OldSplit = { + OldSplit(featureIndex, threshold, OldFeatureType.Continuous, List.empty[Double]) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala new file mode 100644 index 0000000000..8e3bc3849d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree + +import org.apache.spark.annotation.AlphaComponent + + +/** + * :: AlphaComponent :: + * + * Abstraction for Decision Tree models. + * + * TODO: Add support for predicting probabilities and raw predictions + */ +@AlphaComponent +trait DecisionTreeModel { + + /** Root of the decision tree */ + def rootNode: Node + + /** Number of nodes in tree, including leaf nodes. */ + def numNodes: Int = { + 1 + rootNode.numDescendants + } + + /** + * Depth of the tree. + * E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes. + */ + lazy val depth: Int = { + rootNode.subtreeDepth + } + + /** Summary of the model */ + override def toString: String = { + // Implementing classes should generally override this method to be more descriptive. + s"DecisionTreeModel of depth $depth with $numNodes nodes" + } + + /** Full description of model */ + def toDebugString: String = { + val header = toString + "\n" + header + rootNode.subtreeToString(2) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala new file mode 100644 index 0000000000..c84c8b4eb7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import scala.collection.immutable.HashMap + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, BinaryAttribute, NominalAttribute, + NumericAttribute} +import org.apache.spark.sql.types.StructField + + +/** + * :: Experimental :: + * + * Helper utilities for tree-based algorithms + */ +@Experimental +object MetadataUtils { + + /** + * Examine a schema to identify the number of classes in a label column. + * Returns None if the number of labels is not specified, or if the label column is continuous. + */ + def getNumClasses(labelSchema: StructField): Option[Int] = { + Attribute.fromStructField(labelSchema) match { + case numAttr: NumericAttribute => None + case binAttr: BinaryAttribute => Some(2) + case nomAttr: NominalAttribute => nomAttr.getNumValues + } + } + + /** + * Examine a schema to identify categorical (Binary and Nominal) features. + * + * @param featuresSchema Schema of the features column. + * If a feature does not have metadata, it is assumed to be continuous. + * If a feature is Nominal, then it must have the number of values + * specified. + * @return Map: feature index --> number of categories. + * The map's set of keys will be the set of categorical feature indices. + */ + def getCategoricalFeatures(featuresSchema: StructField): Map[Int, Int] = { + val metadata = AttributeGroup.fromStructField(featuresSchema) + if (metadata.attributes.isEmpty) { + HashMap.empty[Int, Int] + } else { + metadata.attributes.get.zipWithIndex.flatMap { case (attr, idx) => + if (attr == null) { + Iterator() + } else { + attr match { + case numAttr: NumericAttribute => Iterator() + case binAttr: BinaryAttribute => Iterator(idx -> 2) + case nomAttr: NominalAttribute => + nomAttr.getNumValues match { + case Some(numValues: Int) => Iterator(idx -> numValues) + case None => throw new IllegalArgumentException(s"Feature $idx is marked as" + + " Nominal (categorical), but it does not have the number of values specified.") + } + } + } + }.toMap + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index b9d0c56dd1..dfe3a0b691 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -1147,7 +1147,10 @@ object DecisionTree extends Serializable with Logging { } } - assert(splits.length > 0) + // TODO: Do not fail; just ignore the useless feature. + assert(splits.length > 0, + s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." + + " Please remove this feature and then try again.") // set number of splits accordingly metadata.setNumSplits(featureIndex, splits.length) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index c02c79f094..0e31c7ed58 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -81,11 +81,11 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) /** * Method to validate a gradient boosting model * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * @param validationInput Validation dataset: - RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - Should be different from and follow the same distribution as input. - e.g., these two datasets could be created from an original dataset - by using [[org.apache.spark.rdd.RDD.randomSplit()]] + * @param validationInput Validation dataset. + * This dataset should be different from the training dataset, + * but it should follow the same distribution. + * E.g., these two datasets could be created from an original dataset + * by using [[org.apache.spark.rdd.RDD.randomSplit()]] * @return a gradient boosted trees model that can be used for prediction */ def runWithValidation( @@ -194,8 +194,6 @@ object GradientBoostedTrees extends Logging { val firstTreeWeight = 1.0 baseLearners(0) = firstTreeModel baseLearnerWeights(0) = firstTreeWeight - val startingModel = new GradientBoostedTreesModel( - Regression, Array(firstTreeModel), baseLearnerWeights.slice(0, 1)) var predError: RDD[(Double, Double)] = GradientBoostedTreesModel. computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index db01f2e229..055e60c7d9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -249,7 +249,7 @@ private class RandomForest ( nodeIdCache.get.deleteAllCheckpoints() } catch { case e:IOException => - logWarning(s"delete all chackpoints failed. Error reason: ${e.getMessage}") + logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}") } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index 664c8df019..2d6b01524f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -89,14 +89,14 @@ object BoostingStrategy { * @return Configuration for boosting algorithm */ def defaultParams(algo: Algo): BoostingStrategy = { - val treeStragtegy = Strategy.defaultStategy(algo) - treeStragtegy.maxDepth = 3 + val treeStrategy = Strategy.defaultStategy(algo) + treeStrategy.maxDepth = 3 algo match { case Algo.Classification => - treeStragtegy.numClasses = 2 - new BoostingStrategy(treeStragtegy, LogLoss) + treeStrategy.numClasses = 2 + new BoostingStrategy(treeStrategy, LogLoss) case Algo.Regression => - new BoostingStrategy(treeStragtegy, SquaredError) + new BoostingStrategy(treeStrategy, SquaredError) case _ => throw new IllegalArgumentException(s"$algo is not supported by boosting.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala index 6f570b4e09..2bdef73c4a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.loss import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel -import org.apache.spark.rdd.RDD + /** * :: DeveloperApi :: @@ -45,9 +45,8 @@ object AbsoluteError extends Loss { if (label - prediction < 0) 1.0 else -1.0 } - override def computeError(prediction: Double, label: Double): Double = { + override private[mllib] def computeError(prediction: Double, label: Double): Double = { val err = label - prediction math.abs(err) } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala index 24ee9f3d51..778c24526d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD + /** * :: DeveloperApi :: @@ -47,10 +47,9 @@ object LogLoss extends Loss { - 4.0 * label / (1.0 + math.exp(2.0 * label * prediction)) } - override def computeError(prediction: Double, label: Double): Double = { + override private[mllib] def computeError(prediction: Double, label: Double): Double = { val margin = 2.0 * label * prediction // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable. 2.0 * MLUtils.log1pExp(-margin) } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala index d3b82b752f..64ffccbce0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala @@ -22,6 +22,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.rdd.RDD + /** * :: DeveloperApi :: * Trait for adding "pluggable" loss functions for the gradient boosting algorithm. @@ -57,6 +58,5 @@ trait Loss extends Serializable { * @param label True label. * @return Measure of model error on datapoint. */ - def computeError(prediction: Double, label: Double): Double - + private[mllib] def computeError(prediction: Double, label: Double): Double } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala index 58857ae15e..a5582d3ef3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.loss import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel -import org.apache.spark.rdd.RDD + /** * :: DeveloperApi :: @@ -45,9 +45,8 @@ object SquaredError extends Loss { 2.0 * (prediction - label) } - override def computeError(prediction: Double, label: Double): Double = { + override private[mllib] def computeError(prediction: Double, label: Double): Double = { val err = prediction - label err * err } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index c9bafd60fb..331af42853 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -113,11 +113,13 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable DecisionTreeModel.SaveLoadV1_0.save(sc, path, this) } - override protected def formatVersion: String = "1.0" + override protected def formatVersion: String = DecisionTreeModel.formatVersion } object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { + private[spark] def formatVersion: String = "1.0" + private[tree] object SaveLoadV1_0 { def thisFormatVersion: String = "1.0" diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 4f72bb8014..708ba04b56 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -175,7 +175,7 @@ class Node ( } } -private[tree] object Node { +private[spark] object Node { /** * Return a node with the given node id (but nothing else set). diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index fef3d2acb2..8341219bfa 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -38,6 +38,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils + /** * :: Experimental :: * Represents a random forest model. @@ -47,7 +48,7 @@ import org.apache.spark.util.Utils */ @Experimental class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel]) - extends TreeEnsembleModel(algo, trees, Array.fill(trees.size)(1.0), + extends TreeEnsembleModel(algo, trees, Array.fill(trees.length)(1.0), combiningStrategy = if (algo == Classification) Vote else Average) with Saveable { @@ -58,11 +59,13 @@ class RandomForestModel(override val algo: Algo, override val trees: Array[Decis RandomForestModel.SaveLoadV1_0.thisClassName) } - override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion + override protected def formatVersion: String = RandomForestModel.formatVersion } object RandomForestModel extends Loader[RandomForestModel] { + private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion + override def load(sc: SparkContext, path: String): RandomForestModel = { val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName @@ -102,15 +105,13 @@ class GradientBoostedTreesModel( extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum) with Saveable { - require(trees.size == treeWeights.size) + require(trees.length == treeWeights.length) override def save(sc: SparkContext, path: String): Unit = { TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this, GradientBoostedTreesModel.SaveLoadV1_0.thisClassName) } - override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion - /** * Method to compute error or loss for every iteration of gradient boosting. * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] @@ -138,7 +139,7 @@ class GradientBoostedTreesModel( evaluationArray(0) = predictionAndError.values.mean() val broadcastTrees = sc.broadcast(trees) - (1 until numIterations).map { nTree => + (1 until numIterations).foreach { nTree => predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter => val currentTree = broadcastTrees.value(nTree) val currentTreeWeight = localTreeWeights(nTree) @@ -155,6 +156,7 @@ class GradientBoostedTreesModel( evaluationArray } + override protected def formatVersion: String = GradientBoostedTreesModel.formatVersion } object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { @@ -200,17 +202,17 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { loss: Loss): RDD[(Double, Double)] = { val newPredError = data.zip(predictionAndError).mapPartitions { iter => - iter.map { - case (lp, (pred, error)) => { - val newPred = pred + tree.predict(lp.features) * treeWeight - val newError = loss.computeError(newPred, lp.label) - (newPred, newError) - } + iter.map { case (lp, (pred, error)) => + val newPred = pred + tree.predict(lp.features) * treeWeight + val newError = loss.computeError(newPred, lp.label) + (newPred, newError) } } newPredError } + private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion + override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = { val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName @@ -340,12 +342,12 @@ private[tree] sealed class TreeEnsembleModel( } /** - * Get number of trees in forest. + * Get number of trees in ensemble. */ - def numTrees: Int = trees.size + def numTrees: Int = trees.length /** - * Get total number of nodes, summed over all trees in the forest. + * Get total number of nodes, summed over all trees in the ensemble. */ def totalNumNodes: Int = trees.map(_.numNodes).sum } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java new file mode 100644 index 0000000000..43b8787f9d --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification; + +import java.io.File; +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.util.Utils; + + +public class JavaDecisionTreeClassifierSuite implements Serializable { + + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaDecisionTreeClassifierSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runDT() { + int nPoints = 20; + double A = 2.0; + double B = -1.5; + + JavaRDD<LabeledPoint> data = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>(); + DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); + + // This tests setters. Training with various options is tested in Scala. + DecisionTreeClassifier dt = new DecisionTreeClassifier() + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (int i = 0; i < DecisionTreeClassifier.supportedImpurities().length; ++i) { + dt.setImpurity(DecisionTreeClassifier.supportedImpurities()[i]); + } + DecisionTreeClassificationModel model = dt.fit(dataFrame); + + model.transform(dataFrame); + model.numNodes(); + model.depth(); + model.toDebugString(); + + /* + // TODO: Add test once save/load are implemented. + File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); + String path = tempDir.toURI().toString(); + try { + model3.save(sc.sc(), path); + DecisionTreeClassificationModel sameModel = + DecisionTreeClassificationModel.load(sc.sc(), path); + TreeTests.checkEqual(model3, sameModel); + } finally { + Utils.deleteRecursively(tempDir); + } + */ + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java new file mode 100644 index 0000000000..a3a339004f --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression; + +import java.io.File; +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.util.Utils; + + +public class JavaDecisionTreeRegressorSuite implements Serializable { + + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaDecisionTreeRegressorSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runDT() { + int nPoints = 20; + double A = 2.0; + double B = -1.5; + + JavaRDD<LabeledPoint> data = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>(); + DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); + + // This tests setters. Training with various options is tested in Scala. + DecisionTreeRegressor dt = new DecisionTreeRegressor() + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (int i = 0; i < DecisionTreeRegressor.supportedImpurities().length; ++i) { + dt.setImpurity(DecisionTreeRegressor.supportedImpurities()[i]); + } + DecisionTreeRegressionModel model = dt.fit(dataFrame); + + model.transform(dataFrame); + model.numNodes(); + model.depth(); + model.toDebugString(); + + /* + // TODO: Add test once save/load are implemented. + File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); + String path = tempDir.toURI().toString(); + try { + model2.save(sc.sc(), path); + DecisionTreeRegressionModel sameModel = DecisionTreeRegressionModel.load(sc.sc(), path); + TreeTests.checkEqual(model2, sameModel); + } finally { + Utils.deleteRecursively(tempDir); + } + */ + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala index 0dcfe5a200..17ddd335de 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala @@ -44,7 +44,7 @@ class AttributeGroupSuite extends FunSuite { group("abc") } assert(group === AttributeGroup.fromMetadata(group.toMetadataImpl, group.name)) - assert(group === AttributeGroup.fromStructField(group.toStructField)) + assert(group === AttributeGroup.fromStructField(group.toStructField())) } test("attribute group without attributes") { @@ -54,7 +54,7 @@ class AttributeGroupSuite extends FunSuite { assert(group0.size === 10) assert(group0.attributes.isEmpty) assert(group0 === AttributeGroup.fromMetadata(group0.toMetadataImpl, group0.name)) - assert(group0 === AttributeGroup.fromStructField(group0.toStructField)) + assert(group0 === AttributeGroup.fromStructField(group0.toStructField())) val group1 = new AttributeGroup("item") assert(group1.name === "item") diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala index 6ec35b0365..3e1a7196e3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala @@ -36,9 +36,9 @@ class AttributeSuite extends FunSuite { assert(attr.max.isEmpty) assert(attr.std.isEmpty) assert(attr.sparsity.isEmpty) - assert(attr.toMetadata() === metadata) - assert(attr.toMetadata(withType = false) === metadata) - assert(attr.toMetadata(withType = true) === metadataWithType) + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = false) === metadata) + assert(attr.toMetadataImpl(withType = true) === metadataWithType) assert(attr === Attribute.fromMetadata(metadata)) assert(attr === Attribute.fromMetadata(metadataWithType)) intercept[NoSuchElementException] { @@ -59,9 +59,9 @@ class AttributeSuite extends FunSuite { assert(!attr.isNominal) assert(attr.name === Some(name)) assert(attr.index === Some(index)) - assert(attr.toMetadata() === metadata) - assert(attr.toMetadata(withType = false) === metadata) - assert(attr.toMetadata(withType = true) === metadataWithType) + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = false) === metadata) + assert(attr.toMetadataImpl(withType = true) === metadataWithType) assert(attr === Attribute.fromMetadata(metadata)) assert(attr === Attribute.fromMetadata(metadataWithType)) val field = attr.toStructField() @@ -81,7 +81,7 @@ class AttributeSuite extends FunSuite { assert(attr2.max === Some(1.0)) assert(attr2.std === Some(0.5)) assert(attr2.sparsity === Some(0.3)) - assert(attr2 === Attribute.fromMetadata(attr2.toMetadata())) + assert(attr2 === Attribute.fromMetadata(attr2.toMetadataImpl())) } test("bad numeric attributes") { @@ -105,9 +105,9 @@ class AttributeSuite extends FunSuite { assert(attr.values.isEmpty) assert(attr.numValues.isEmpty) assert(attr.isOrdinal.isEmpty) - assert(attr.toMetadata() === metadata) - assert(attr.toMetadata(withType = true) === metadata) - assert(attr.toMetadata(withType = false) === metadataWithoutType) + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = true) === metadata) + assert(attr.toMetadataImpl(withType = false) === metadataWithoutType) assert(attr === Attribute.fromMetadata(metadata)) assert(attr === NominalAttribute.fromMetadata(metadataWithoutType)) intercept[NoSuchElementException] { @@ -135,9 +135,9 @@ class AttributeSuite extends FunSuite { assert(attr.values === Some(values)) assert(attr.indexOf("medium") === 1) assert(attr.getValue(1) === "medium") - assert(attr.toMetadata() === metadata) - assert(attr.toMetadata(withType = true) === metadata) - assert(attr.toMetadata(withType = false) === metadataWithoutType) + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = true) === metadata) + assert(attr.toMetadataImpl(withType = false) === metadataWithoutType) assert(attr === Attribute.fromMetadata(metadata)) assert(attr === NominalAttribute.fromMetadata(metadataWithoutType)) assert(attr.withoutIndex === Attribute.fromStructField(attr.toStructField())) @@ -147,8 +147,8 @@ class AttributeSuite extends FunSuite { assert(attr2.index.isEmpty) assert(attr2.values.get === Array("small", "medium", "large", "x-large")) assert(attr2.indexOf("x-large") === 3) - assert(attr2 === Attribute.fromMetadata(attr2.toMetadata())) - assert(attr2 === NominalAttribute.fromMetadata(attr2.toMetadata(withType = false))) + assert(attr2 === Attribute.fromMetadata(attr2.toMetadataImpl())) + assert(attr2 === NominalAttribute.fromMetadata(attr2.toMetadataImpl(withType = false))) } test("bad nominal attributes") { @@ -168,9 +168,9 @@ class AttributeSuite extends FunSuite { assert(attr.name.isEmpty) assert(attr.index.isEmpty) assert(attr.values.isEmpty) - assert(attr.toMetadata() === metadata) - assert(attr.toMetadata(withType = true) === metadata) - assert(attr.toMetadata(withType = false) === metadataWithoutType) + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = true) === metadata) + assert(attr.toMetadataImpl(withType = false) === metadataWithoutType) assert(attr === Attribute.fromMetadata(metadata)) assert(attr === BinaryAttribute.fromMetadata(metadataWithoutType)) intercept[NoSuchElementException] { @@ -196,9 +196,9 @@ class AttributeSuite extends FunSuite { assert(attr.name === Some(name)) assert(attr.index === Some(index)) assert(attr.values.get === values) - assert(attr.toMetadata() === metadata) - assert(attr.toMetadata(withType = true) === metadata) - assert(attr.toMetadata(withType = false) === metadataWithoutType) + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = true) === metadata) + assert(attr.toMetadataImpl(withType = false) === metadataWithoutType) assert(attr === Attribute.fromMetadata(metadata)) assert(attr === BinaryAttribute.fromMetadata(metadataWithoutType)) assert(attr.withoutIndex === Attribute.fromStructField(attr.toStructField())) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala new file mode 100644 index 0000000000..af88595df5 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.scalatest.FunSuite + +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, + DecisionTreeSuite => OldDecisionTreeSuite} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext { + + import DecisionTreeClassifierSuite.compareAPIs + + private var categoricalDataPointsRDD: RDD[LabeledPoint] = _ + private var orderedLabeledPointsWithLabel0RDD: RDD[LabeledPoint] = _ + private var orderedLabeledPointsWithLabel1RDD: RDD[LabeledPoint] = _ + private var categoricalDataPointsForMulticlassRDD: RDD[LabeledPoint] = _ + private var continuousDataPointsForMulticlassRDD: RDD[LabeledPoint] = _ + private var categoricalDataPointsForMulticlassForOrderedFeaturesRDD: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + categoricalDataPointsRDD = + sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints()) + orderedLabeledPointsWithLabel0RDD = + sc.parallelize(OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()) + orderedLabeledPointsWithLabel1RDD = + sc.parallelize(OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()) + categoricalDataPointsForMulticlassRDD = + sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlass()) + continuousDataPointsForMulticlassRDD = + sc.parallelize(OldDecisionTreeSuite.generateContinuousDataPointsForMulticlass()) + categoricalDataPointsForMulticlassForOrderedFeaturesRDD = sc.parallelize( + OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests calling train() + ///////////////////////////////////////////////////////////////////////////// + + test("Binary classification stump with ordered categorical features") { + val dt = new DecisionTreeClassifier() + .setImpurity("gini") + .setMaxDepth(2) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3, 1-> 3) + val numClasses = 2 + compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures, numClasses) + } + + test("Binary classification stump with fixed labels 0,1 for Entropy,Gini") { + val dt = new DecisionTreeClassifier() + .setMaxDepth(3) + .setMaxBins(100) + val numClasses = 2 + Array(orderedLabeledPointsWithLabel0RDD, orderedLabeledPointsWithLabel1RDD).foreach { rdd => + DecisionTreeClassifier.supportedImpurities.foreach { impurity => + dt.setImpurity(impurity) + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + } + } + + test("Multiclass classification stump with 3-ary (unordered) categorical features") { + val rdd = categoricalDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + val numClasses = 3 + val categoricalFeatures = Map(0 -> 3, 1 -> 3) + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("Binary classification stump with 1 continuous feature, to check off-by-1 error") { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(3.0))) + val rdd = sc.parallelize(arr) + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + val numClasses = 2 + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + test("Binary classification stump with 2 continuous features") { + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0))))) + val rdd = sc.parallelize(arr) + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + val numClasses = 2 + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + test("Multiclass classification stump with unordered categorical features," + + " with just enough bins") { + val maxBins = 2 * (math.pow(2, 3 - 1).toInt - 1) // just enough bins to allow unordered features + val rdd = categoricalDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(maxBins) + val categoricalFeatures = Map(0 -> 3, 1 -> 3) + val numClasses = 3 + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("Multiclass classification stump with continuous features") { + val rdd = continuousDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val numClasses = 3 + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + test("Multiclass classification stump with continuous + unordered categorical features") { + val rdd = continuousDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3) + val numClasses = 3 + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("Multiclass classification stump with 10-ary (ordered) categorical features") { + val rdd = categoricalDataPointsForMulticlassForOrderedFeaturesRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 10, 1 -> 10) + val numClasses = 3 + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("Multiclass classification tree with 10-ary (ordered) categorical features," + + " with just enough bins") { + val rdd = categoricalDataPointsForMulticlassForOrderedFeaturesRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(10) + val categoricalFeatures = Map(0 -> 10, 1 -> 10) + val numClasses = 3 + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("split must satisfy min instances per node requirements") { + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) + val rdd = sc.parallelize(arr) + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(2) + .setMinInstancesPerNode(2) + val numClasses = 2 + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + test("do not choose split that does not satisfy min instance per node requirements") { + // if a split does not satisfy min instances per node requirements, + // this split is invalid, even though the information gain of split is large. + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0))) + val rdd = sc.parallelize(arr) + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxBins(2) + .setMaxDepth(2) + .setMinInstancesPerNode(2) + val categoricalFeatures = Map(0 -> 2, 1-> 2) + val numClasses = 2 + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("split must satisfy min info gain requirements") { + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) + val rdd = sc.parallelize(arr) + + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(2) + .setMinInfoGain(1.0) + val numClasses = 2 + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + // TODO: Reinstate test once save/load are implemented + /* + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val oldModel = OldDecisionTreeSuite.createModel(OldAlgo.Classification) + val newModel = DecisionTreeClassificationModel.fromOld(oldModel) + + // Save model, load it back, and compare. + try { + newModel.save(sc, path) + val sameNewModel = DecisionTreeClassificationModel.load(sc, path) + TreeTests.checkEqual(newModel, sameNewModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + */ +} + +private[ml] object DecisionTreeClassifierSuite extends FunSuite { + + /** + * Train 2 decision trees on the given dataset, one using the old API and one using the new API. + * Convert the old tree to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + dt: DecisionTreeClassifier, + categoricalFeatures: Map[Int, Int], + numClasses: Int): Unit = { + val oldStrategy = dt.getOldStrategy(categoricalFeatures, numClasses) + val oldTree = OldDecisionTree.train(data, oldStrategy) + val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) + val newTree = dt.fit(newData) + // Use parent, fittingParamMap from newTree since these are not checked anyways. + val oldTreeAsNew = DecisionTreeClassificationModel.fromOld(oldTree, newTree.parent, + newTree.fittingParamMap, categoricalFeatures) + TreeTests.checkEqual(oldTreeAsNew, newTree) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 81ef831c42..1b261b2643 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -228,7 +228,7 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext { } val attrGroup = new AttributeGroup("features", featureAttributes) val densePoints1WithMeta = - densePoints1.select(densePoints1("features").as("features", attrGroup.toMetadata)) + densePoints1.select(densePoints1("features").as("features", attrGroup.toMetadata())) val vectorIndexer = getIndexer.setMaxCategories(2) val model = vectorIndexer.fit(densePoints1WithMeta) // Check that ML metadata are preserved. diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala new file mode 100644 index 0000000000..2e57d4ce37 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.impl + +import scala.collection.JavaConverters._ + +import org.scalatest.FunSuite + +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} +import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.tree.{DecisionTreeModel, InternalNode, LeafNode, Node} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{SQLContext, DataFrame} + + +private[ml] object TreeTests extends FunSuite { + + /** + * Convert the given data to a DataFrame, and set the features and label metadata. + * @param data Dataset. Categorical features and labels must already have 0-based indices. + * This must be non-empty. + * @param categoricalFeatures Map: categorical feature index -> number of distinct values + * @param numClasses Number of classes label can take. If 0, mark as continuous. + * @return DataFrame with metadata + */ + def setMetadata( + data: RDD[LabeledPoint], + categoricalFeatures: Map[Int, Int], + numClasses: Int): DataFrame = { + val sqlContext = new SQLContext(data.sparkContext) + import sqlContext.implicits._ + val df = data.toDF() + val numFeatures = data.first().features.size + val featuresAttributes = Range(0, numFeatures).map { feature => + if (categoricalFeatures.contains(feature)) { + NominalAttribute.defaultAttr.withIndex(feature).withNumValues(categoricalFeatures(feature)) + } else { + NumericAttribute.defaultAttr.withIndex(feature) + } + }.toArray + val featuresMetadata = new AttributeGroup("features", featuresAttributes).toMetadata() + val labelAttribute = if (numClasses == 0) { + NumericAttribute.defaultAttr.withName("label") + } else { + NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses) + } + val labelMetadata = labelAttribute.toMetadata() + df.select(df("features").as("features", featuresMetadata), + df("label").as("label", labelMetadata)) + } + + /** Java-friendly version of [[setMetadata()]] */ + def setMetadata( + data: JavaRDD[LabeledPoint], + categoricalFeatures: java.util.Map[java.lang.Integer, java.lang.Integer], + numClasses: Int): DataFrame = { + setMetadata(data.rdd, categoricalFeatures.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + numClasses) + } + + /** + * Check if the two trees are exactly the same. + * Note: I hesitate to override Node.equals since it could cause problems if users + * make mistakes such as creating loops of Nodes. + * If the trees are not equal, this prints the two trees and throws an exception. + */ + def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = { + try { + checkEqual(a.rootNode, b.rootNode) + } catch { + case ex: Exception => + throw new AssertionError("checkEqual failed since the two trees were not identical.\n" + + "TREE A:\n" + a.toDebugString + "\n" + + "TREE B:\n" + b.toDebugString + "\n", ex) + } + } + + /** + * Return true iff the two nodes and their descendants are exactly the same. + * Note: I hesitate to override Node.equals since it could cause problems if users + * make mistakes such as creating loops of Nodes. + */ + private def checkEqual(a: Node, b: Node): Unit = { + assert(a.prediction === b.prediction) + assert(a.impurity === b.impurity) + (a, b) match { + case (aye: InternalNode, bee: InternalNode) => + assert(aye.split === bee.split) + checkEqual(aye.leftChild, bee.leftChild) + checkEqual(aye.rightChild, bee.rightChild) + case (aye: LeafNode, bee: LeafNode) => // do nothing + case _ => + throw new AssertionError("Found mismatched nodes") + } + } + + // TODO: Reinstate after adding ensembles + /** + * Check if the two models are exactly the same. + * If the models are not equal, this throws an exception. + */ + /* + def checkEqual(a: TreeEnsembleModel, b: TreeEnsembleModel): Unit = { + try { + a.getTrees.zip(b.getTrees).foreach { case (treeA, treeB) => + TreeTests.checkEqual(treeA, treeB) + } + assert(a.getTreeWeights === b.getTreeWeights) + } catch { + case ex: Exception => throw new AssertionError( + "checkEqual failed since the two tree ensembles were not identical") + } + } + */ +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala new file mode 100644 index 0000000000..0b40fe33fa --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.scalatest.FunSuite + +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, + DecisionTreeSuite => OldDecisionTreeSuite} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext { + + import DecisionTreeRegressorSuite.compareAPIs + + private var categoricalDataPointsRDD: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + categoricalDataPointsRDD = + sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints()) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests calling train() + ///////////////////////////////////////////////////////////////////////////// + + test("Regression stump with 3-ary (ordered) categorical features") { + val dt = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3, 1-> 3) + compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures) + } + + test("Regression stump with binary (ordered) categorical features") { + val dt = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 2, 1-> 2) + compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + // TODO: test("model save/load") +} + +private[ml] object DecisionTreeRegressorSuite extends FunSuite { + + /** + * Train 2 decision trees on the given dataset, one using the old API and one using the new API. + * Convert the old tree to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + dt: DecisionTreeRegressor, + categoricalFeatures: Map[Int, Int]): Unit = { + val oldStrategy = dt.getOldStrategy(categoricalFeatures) + val oldTree = OldDecisionTree.train(data, oldStrategy) + val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) + val newTree = dt.fit(newData) + // Use parent, fittingParamMap from newTree since these are not checked anyways. + val oldTreeAsNew = DecisionTreeRegressionModel.fromOld(oldTree, newTree.parent, + newTree.fittingParamMap, categoricalFeatures) + TreeTests.checkEqual(oldTreeAsNew, newTree) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 4c162df810..249b8eae19 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -36,6 +36,10 @@ import org.apache.spark.util.Utils class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { + ///////////////////////////////////////////////////////////////////////////// + // Tests examining individual elements of training + ///////////////////////////////////////////////////////////////////////////// + test("Binary classification with continuous features: split and bin calculation") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) @@ -254,6 +258,165 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { assert(bins(0).length === 0) } + test("Avoid aggregation on the last level") { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) + val input = sc.parallelize(arr) + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1, + numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) + + val topNode = Node.emptyNode(nodeIndex = 1) + assert(topNode.predict.predict === Double.MinValue) + assert(topNode.impurity === -1.0) + assert(topNode.isLeaf === false) + + val nodesForGroup = Map((0, Array(topNode))) + val treeToNodeToIndexInfo = Map((0, Map( + (topNode.id, new RandomForest.NodeIndexInfo(0, None)) + ))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) + + // don't enqueue leaf nodes into node queue + assert(nodeQueue.isEmpty) + + // set impurity and predict for topNode + assert(topNode.predict.predict !== Double.MinValue) + assert(topNode.impurity !== -1.0) + + // set impurity and predict for child nodes + assert(topNode.leftNode.get.predict.predict === 0.0) + assert(topNode.rightNode.get.predict.predict === 1.0) + assert(topNode.leftNode.get.impurity === 0.0) + assert(topNode.rightNode.get.impurity === 0.0) + } + + test("Avoid aggregation if impurity is 0.0") { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) + val input = sc.parallelize(arr) + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) + + val topNode = Node.emptyNode(nodeIndex = 1) + assert(topNode.predict.predict === Double.MinValue) + assert(topNode.impurity === -1.0) + assert(topNode.isLeaf === false) + + val nodesForGroup = Map((0, Array(topNode))) + val treeToNodeToIndexInfo = Map((0, Map( + (topNode.id, new RandomForest.NodeIndexInfo(0, None)) + ))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) + + // don't enqueue a node into node queue if its impurity is 0.0 + assert(nodeQueue.isEmpty) + + // set impurity and predict for topNode + assert(topNode.predict.predict !== Double.MinValue) + assert(topNode.impurity !== -1.0) + + // set impurity and predict for child nodes + assert(topNode.leftNode.get.predict.predict === 0.0) + assert(topNode.rightNode.get.predict.predict === 1.0) + assert(topNode.leftNode.get.impurity === 0.0) + assert(topNode.rightNode.get.impurity === 0.0) + } + + test("Second level node building with vs. without groups") { + val arr = DecisionTreeSuite.generateOrderedLabeledPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Entropy, 3, 2, 100) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + + // Train a 1-node model + val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1, + numClasses = 2, maxBins = 100) + val modelOneNode = DecisionTree.train(rdd, strategyOneNode) + val rootNode1 = modelOneNode.topNode.deepCopy() + val rootNode2 = modelOneNode.topNode.deepCopy() + assert(rootNode1.leftNode.nonEmpty) + assert(rootNode1.rightNode.nonEmpty) + + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) + + // Single group second level tree construction. + val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get))) + val treeToNodeToIndexInfo = Map((0, Map( + (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)), + (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None))))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) + val children1 = new Array[Node](2) + children1(0) = rootNode1.leftNode.get + children1(1) = rootNode1.rightNode.get + + // Train one second-level node at a time. + val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get))) + val treeToNodeToIndexInfoA = Map((0, Map( + (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) + nodeQueue.clear() + DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), + nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue) + val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get))) + val treeToNodeToIndexInfoB = Map((0, Map( + (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) + nodeQueue.clear() + DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), + nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue) + val children2 = new Array[Node](2) + children2(0) = rootNode2.leftNode.get + children2(1) = rootNode2.rightNode.get + + // Verify whether the splits obtained using single group and multiple group level + // construction strategies are the same. + for (i <- 0 until 2) { + assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0) + assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0) + assert(children1(i).split === children2(i).split) + assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty) + val stats1 = children1(i).stats.get + val stats2 = children2(i).stats.get + assert(stats1.gain === stats2.gain) + assert(stats1.impurity === stats2.impurity) + assert(stats1.leftImpurity === stats2.leftImpurity) + assert(stats1.rightImpurity === stats2.rightImpurity) + assert(children1(i).predict.predict === children2(i).predict.predict) + } + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests calling train() + ///////////////////////////////////////////////////////////////////////////// test("Binary classification stump with ordered categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() @@ -438,76 +601,6 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { assert(rootNode.predict.predict === 1) } - test("Second level node building with vs. without groups") { - val arr = DecisionTreeSuite.generateOrderedLabeledPoints() - assert(arr.length === 1000) - val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - assert(splits.length === 2) - assert(splits(0).length === 99) - assert(bins.length === 2) - assert(bins(0).length === 100) - - // Train a 1-node model - val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1, - numClasses = 2, maxBins = 100) - val modelOneNode = DecisionTree.train(rdd, strategyOneNode) - val rootNode1 = modelOneNode.topNode.deepCopy() - val rootNode2 = modelOneNode.topNode.deepCopy() - assert(rootNode1.leftNode.nonEmpty) - assert(rootNode1.rightNode.nonEmpty) - - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) - - // Single group second level tree construction. - val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get))) - val treeToNodeToIndexInfo = Map((0, Map( - (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)), - (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None))))) - val nodeQueue = new mutable.Queue[(Int, Node)]() - DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1), - nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) - val children1 = new Array[Node](2) - children1(0) = rootNode1.leftNode.get - children1(1) = rootNode1.rightNode.get - - // Train one second-level node at a time. - val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get))) - val treeToNodeToIndexInfoA = Map((0, Map( - (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) - nodeQueue.clear() - DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), - nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue) - val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get))) - val treeToNodeToIndexInfoB = Map((0, Map( - (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) - nodeQueue.clear() - DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), - nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue) - val children2 = new Array[Node](2) - children2(0) = rootNode2.leftNode.get - children2(1) = rootNode2.rightNode.get - - // Verify whether the splits obtained using single group and multiple group level - // construction strategies are the same. - for (i <- 0 until 2) { - assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0) - assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0) - assert(children1(i).split === children2(i).split) - assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty) - val stats1 = children1(i).stats.get - val stats2 = children2(i).stats.get - assert(stats1.gain === stats2.gain) - assert(stats1.impurity === stats2.impurity) - assert(stats1.leftImpurity === stats2.leftImpurity) - assert(stats1.rightImpurity === stats2.rightImpurity) - assert(children1(i).predict.predict === children2(i).predict.predict) - } - } - test("Multiclass classification stump with 3-ary (unordered) categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() val rdd = sc.parallelize(arr) @@ -528,11 +621,11 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { } test("Binary classification stump with 1 continuous feature, to check off-by-1 error") { - val arr = new Array[LabeledPoint](4) - arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0)) - arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0)) - arr(2) = new LabeledPoint(1.0, Vectors.dense(2.0)) - arr(3) = new LabeledPoint(1.0, Vectors.dense(3.0)) + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(3.0))) val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClasses = 2) @@ -544,11 +637,11 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { } test("Binary classification stump with 2 continuous features") { - val arr = new Array[LabeledPoint](4) - arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) - arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) - arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) - arr(3) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0)))) + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0))))) val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, @@ -668,11 +761,10 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { } test("split must satisfy min instances per node requirements") { - val arr = new Array[LabeledPoint](3) - arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) - arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) - arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))) - + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, numClasses = 2, minInstancesPerNode = 2) @@ -695,11 +787,11 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { test("do not choose split that does not satisfy min instance per node requirements") { // if a split does not satisfy min instances per node requirements, // this split is invalid, even though the information gain of split is large. - val arr = new Array[LabeledPoint](4) - arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0)) - arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0, 1.0)) - arr(2) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)) - arr(3) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)) + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0))) val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, @@ -715,10 +807,10 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { } test("split must satisfy min info gain requirements") { - val arr = new Array[LabeledPoint](3) - arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) - arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) - arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))) + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) val input = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, @@ -739,91 +831,9 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { assert(gain == InformationGainStats.invalidInformationGainStats) } - test("Avoid aggregation on the last level") { - val arr = new Array[LabeledPoint](4) - arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)) - arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)) - arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)) - arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)) - val input = sc.parallelize(arr) - - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1, - numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) - val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) - - val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) - - val topNode = Node.emptyNode(nodeIndex = 1) - assert(topNode.predict.predict === Double.MinValue) - assert(topNode.impurity === -1.0) - assert(topNode.isLeaf === false) - - val nodesForGroup = Map((0, Array(topNode))) - val treeToNodeToIndexInfo = Map((0, Map( - (topNode.id, new RandomForest.NodeIndexInfo(0, None)) - ))) - val nodeQueue = new mutable.Queue[(Int, Node)]() - DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), - nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) - - // don't enqueue leaf nodes into node queue - assert(nodeQueue.isEmpty) - - // set impurity and predict for topNode - assert(topNode.predict.predict !== Double.MinValue) - assert(topNode.impurity !== -1.0) - - // set impurity and predict for child nodes - assert(topNode.leftNode.get.predict.predict === 0.0) - assert(topNode.rightNode.get.predict.predict === 1.0) - assert(topNode.leftNode.get.impurity === 0.0) - assert(topNode.rightNode.get.impurity === 0.0) - } - - test("Avoid aggregation if impurity is 0.0") { - val arr = new Array[LabeledPoint](4) - arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)) - arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)) - arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)) - arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)) - val input = sc.parallelize(arr) - - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, - numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) - val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) - - val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) - - val topNode = Node.emptyNode(nodeIndex = 1) - assert(topNode.predict.predict === Double.MinValue) - assert(topNode.impurity === -1.0) - assert(topNode.isLeaf === false) - - val nodesForGroup = Map((0, Array(topNode))) - val treeToNodeToIndexInfo = Map((0, Map( - (topNode.id, new RandomForest.NodeIndexInfo(0, None)) - ))) - val nodeQueue = new mutable.Queue[(Int, Node)]() - DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), - nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) - - // don't enqueue a node into node queue if its impurity is 0.0 - assert(nodeQueue.isEmpty) - - // set impurity and predict for topNode - assert(topNode.predict.predict !== Double.MinValue) - assert(topNode.impurity !== -1.0) - - // set impurity and predict for child nodes - assert(topNode.leftNode.get.predict.predict === 0.0) - assert(topNode.rightNode.get.predict.predict === 1.0) - assert(topNode.leftNode.get.impurity === 0.0) - assert(topNode.rightNode.get.impurity === 0.0) - } + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// test("Node.subtreeIterator") { val model = DecisionTreeSuite.createModel(Classification) @@ -996,8 +1006,9 @@ object DecisionTreeSuite extends FunSuite { /** * Create a tree model. This is deterministic and contains a variety of node and feature types. + * TODO: Update this to be a correct tree (with matching probabilities, impurities, etc.) */ - private[tree] def createModel(algo: Algo): DecisionTreeModel = { + private[mllib] def createModel(algo: Algo): DecisionTreeModel = { val topNode = createInternalNode(id = 1, Continuous) val (node2, node3) = (createLeafNode(id = 2), createInternalNode(id = 3, Categorical)) val (node6, node7) = (createLeafNode(id = 6), createLeafNode(id = 7)) @@ -1017,7 +1028,7 @@ object DecisionTreeSuite extends FunSuite { * make mistakes such as creating loops of Nodes. * If the trees are not equal, this prints the two trees and throws an exception. */ - private[tree] def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = { + private[mllib] def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = { try { assert(a.algo === b.algo) checkEqual(a.topNode, b.topNode) |