Diffstat (limited to 'mllib/src')
 mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala                    |  10
 mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala                        |  43
 mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala       | 155
 mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala                       |   2
 mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala                        | 300
 mllib/src/main/scala/org/apache/spark/ml/package.scala                                     |  12
 mllib/src/main/scala/org/apache/spark/ml/param/params.scala                                |   3
 mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala            | 145
 mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala                                   | 205
 mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala                                  | 151
 mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala                             |  60
 mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala                          |  82
 mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala                        |   5
 mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala                |  12
 mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala                        |   2
 mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala      |  10
 mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala                  |   5
 mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala                        |   5
 mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala                           |   4
 mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala                   |   5
 mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala             |   4
 mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala                          |   2
 mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala            |  32
 mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java |  98
 mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java     |  97
 mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala               |   4
 mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala                    |  42
 mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala  | 274
 mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala                  |   2
 mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala                              | 132
 mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala       |  91
 mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala                   | 373
 32 files changed, 2104 insertions(+), 263 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
index aa27a668f1..d7dee8fed2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
@@ -117,12 +117,12 @@ class AttributeGroup private (
case numeric: NumericAttribute =>
// Skip default numeric attributes.
if (numeric.withoutIndex != NumericAttribute.defaultAttr) {
- numericMetadata += numeric.toMetadata(withType = false)
+ numericMetadata += numeric.toMetadataImpl(withType = false)
}
case nominal: NominalAttribute =>
- nominalMetadata += nominal.toMetadata(withType = false)
+ nominalMetadata += nominal.toMetadataImpl(withType = false)
case binary: BinaryAttribute =>
- binaryMetadata += binary.toMetadata(withType = false)
+ binaryMetadata += binary.toMetadataImpl(withType = false)
}
val attrBldr = new MetadataBuilder
if (numericMetadata.nonEmpty) {
@@ -151,7 +151,7 @@ class AttributeGroup private (
}
/** Converts to ML metadata */
- def toMetadata: Metadata = toMetadata(Metadata.empty)
+ def toMetadata(): Metadata = toMetadata(Metadata.empty)
/** Converts to a StructField with some existing metadata. */
def toStructField(existingMetadata: Metadata): StructField = {
@@ -159,7 +159,7 @@ class AttributeGroup private (
}
/** Converts to a StructField. */
- def toStructField: StructField = toStructField(Metadata.empty)
+ def toStructField(): StructField = toStructField(Metadata.empty)
override def equals(other: Any): Boolean = {
other match {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
index 00b7566aab..5717d6ec2e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
@@ -68,21 +68,32 @@ sealed abstract class Attribute extends Serializable {
* Converts this attribute to [[Metadata]].
* @param withType whether to include the type info
*/
- private[attribute] def toMetadata(withType: Boolean): Metadata
+ private[attribute] def toMetadataImpl(withType: Boolean): Metadata
/**
* Converts this attribute to [[Metadata]]. For numeric attributes, the type info is excluded to
* save space, because numeric type is the default attribute type. For nominal and binary
* attributes, the type info is included.
*/
- private[attribute] def toMetadata(): Metadata = {
+ private[attribute] def toMetadataImpl(): Metadata = {
if (attrType == AttributeType.Numeric) {
- toMetadata(withType = false)
+ toMetadataImpl(withType = false)
} else {
- toMetadata(withType = true)
+ toMetadataImpl(withType = true)
}
}
+ /** Converts to ML metadata with some existing metadata. */
+ def toMetadata(existingMetadata: Metadata): Metadata = {
+ new MetadataBuilder()
+ .withMetadata(existingMetadata)
+ .putMetadata(AttributeKeys.ML_ATTR, toMetadataImpl())
+ .build()
+ }
+
+ /** Converts to ML metadata */
+ def toMetadata(): Metadata = toMetadata(Metadata.empty)
+
/**
* Converts to a [[StructField]] with some existing metadata.
* @param existingMetadata existing metadata to carry over
@@ -90,7 +101,7 @@ sealed abstract class Attribute extends Serializable {
def toStructField(existingMetadata: Metadata): StructField = {
val newMetadata = new MetadataBuilder()
.withMetadata(existingMetadata)
- .putMetadata(AttributeKeys.ML_ATTR, withoutName.withoutIndex.toMetadata())
+ .putMetadata(AttributeKeys.ML_ATTR, withoutName.withoutIndex.toMetadataImpl())
.build()
StructField(name.get, DoubleType, nullable = false, newMetadata)
}
@@ -98,7 +109,7 @@ sealed abstract class Attribute extends Serializable {
/** Converts to a [[StructField]]. */
def toStructField(): StructField = toStructField(Metadata.empty)
- override def toString: String = toMetadata(withType = true).toString
+ override def toString: String = toMetadataImpl(withType = true).toString
}
/** Trait for ML attribute factories. */
@@ -210,7 +221,7 @@ class NumericAttribute private[ml] (
override def isNominal: Boolean = false
/** Convert this attribute to metadata. */
- private[attribute] override def toMetadata(withType: Boolean): Metadata = {
+ override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = {
import org.apache.spark.ml.attribute.AttributeKeys._
val bldr = new MetadataBuilder()
if (withType) bldr.putString(TYPE, attrType.name)
@@ -353,6 +364,20 @@ class NominalAttribute private[ml] (
/** Copy without the `numValues`. */
def withoutNumValues: NominalAttribute = copy(numValues = None)
+ /**
+ * Get the number of values, either from `numValues` or from `values`.
+ * Return None if unknown.
+ */
+ def getNumValues: Option[Int] = {
+ if (numValues.nonEmpty) {
+ numValues
+ } else if (values.nonEmpty) {
+ Some(values.get.length)
+ } else {
+ None
+ }
+ }
+
/** Creates a copy of this attribute with optional changes. */
private def copy(
name: Option[String] = name,
@@ -363,7 +388,7 @@ class NominalAttribute private[ml] (
new NominalAttribute(name, index, isOrdinal, numValues, values)
}
- private[attribute] override def toMetadata(withType: Boolean): Metadata = {
+ override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = {
import org.apache.spark.ml.attribute.AttributeKeys._
val bldr = new MetadataBuilder()
if (withType) bldr.putString(TYPE, attrType.name)
@@ -465,7 +490,7 @@ class BinaryAttribute private[ml] (
new BinaryAttribute(name, index, values)
}
- private[attribute] override def toMetadata(withType: Boolean): Metadata = {
+ override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = {
import org.apache.spark.ml.attribute.AttributeKeys._
val bldr = new MetadataBuilder
if (withType) bldr.putString(TYPE, attrType.name)
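A rough usage sketch of the reworked attribute API (the attribute name and values below are illustrative; only toMetadata(), toMetadata(existingMetadata), and getNumValues come from this diff):

    import org.apache.spark.ml.attribute.NominalAttribute

    // Build a nominal (categorical) attribute and serialize it to column metadata.
    val colorAttr = NominalAttribute.defaultAttr
      .withName("color")
      .withValues("red", "green", "blue")

    // The now-public toMetadata() wraps the attribute under the ML_ATTR key.
    val metadata = colorAttr.toMetadata()

    // getNumValues falls back to values.length when numValues is not set explicitly.
    assert(colorAttr.getNumValues == Some(3))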
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
new file mode 100644
index 0000000000..3855e396b5
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.impl.estimator.{Predictor, PredictionModel}
+import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.param.{Params, ParamMap}
+import org.apache.spark.ml.tree.{DecisionTreeModel, Node}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm
+ * for classification.
+ * It supports both binary and multiclass labels, as well as both continuous and categorical
+ * features.
+ */
+@AlphaComponent
+final class DecisionTreeClassifier
+ extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
+ with DecisionTreeParams
+ with TreeClassifierParams {
+
+ // Override parameter setters from parent trait for Java API compatibility.
+
+ override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
+
+ override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
+
+ override def setMinInstancesPerNode(value: Int): this.type =
+ super.setMinInstancesPerNode(value)
+
+ override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
+
+ override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
+
+ override def setCacheNodeIds(value: Boolean): this.type =
+ super.setCacheNodeIds(value)
+
+ override def setCheckpointInterval(value: Int): this.type =
+ super.setCheckpointInterval(value)
+
+ override def setImpurity(value: String): this.type = super.setImpurity(value)
+
+ override protected def train(
+ dataset: DataFrame,
+ paramMap: ParamMap): DecisionTreeClassificationModel = {
+ val categoricalFeatures: Map[Int, Int] =
+ MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
+ val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match {
+ case Some(n: Int) => n
+ case None => throw new IllegalArgumentException("DecisionTreeClassifier was given input" +
+ " with an invalid label column, without the number of classes specified.")
+ // TODO: Automatically index labels.
+ }
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+ val strategy = getOldStrategy(categoricalFeatures, numClasses)
+ val oldModel = OldDecisionTree.train(oldDataset, strategy)
+ DecisionTreeClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+ }
+
+ /** (private[ml]) Create a Strategy instance to use with the old API. */
+ override private[ml] def getOldStrategy(
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int): OldStrategy = {
+ val strategy = super.getOldStrategy(categoricalFeatures, numClasses)
+ strategy.algo = OldAlgo.Classification
+ strategy.setImpurity(getOldImpurity)
+ strategy
+ }
+}
+
+object DecisionTreeClassifier {
+ /** Accessor for supported impurities */
+ final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for classification.
+ * It supports both binary and multiclass labels, as well as both continuous and categorical
+ * features.
+ */
+@AlphaComponent
+final class DecisionTreeClassificationModel private[ml] (
+ override val parent: DecisionTreeClassifier,
+ override val fittingParamMap: ParamMap,
+ override val rootNode: Node)
+ extends PredictionModel[Vector, DecisionTreeClassificationModel]
+ with DecisionTreeModel with Serializable {
+
+ require(rootNode != null,
+ "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
+
+ override protected def predict(features: Vector): Double = {
+ rootNode.predict(features)
+ }
+
+ override protected def copy(): DecisionTreeClassificationModel = {
+ val m = new DecisionTreeClassificationModel(parent, fittingParamMap, rootNode)
+ Params.inheritValues(this.extractParamMap(), this, m)
+ m
+ }
+
+ override def toString: String = {
+ s"DecisionTreeClassificationModel of depth $depth with $numNodes nodes"
+ }
+
+ /** (private[ml]) Convert to a model in the old API */
+ private[ml] def toOld: OldDecisionTreeModel = {
+ new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Classification)
+ }
+}
+
+private[ml] object DecisionTreeClassificationModel {
+
+ /** (private[ml]) Convert a model from the old API */
+ def fromOld(
+ oldModel: OldDecisionTreeModel,
+ parent: DecisionTreeClassifier,
+ fittingParamMap: ParamMap,
+ categoricalFeatures: Map[Int, Int]): DecisionTreeClassificationModel = {
+ require(oldModel.algo == OldAlgo.Classification,
+ s"Cannot convert non-classification DecisionTreeModel (old API) to" +
+ s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}")
+ val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
+ new DecisionTreeClassificationModel(parent, fittingParamMap, rootNode)
+ }
+}
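A minimal training sketch for the new classifier, assuming a DataFrame `training` whose "label" column carries nominal metadata with the class count (e.g. produced by StringIndexer) and whose "features" column holds Vectors:

    import org.apache.spark.ml.classification.DecisionTreeClassifier

    val dt = new DecisionTreeClassifier()
      .setMaxDepth(4)
      .setImpurity("gini")

    // train() reads numClasses and categorical-feature info from column metadata.
    val model = dt.fit(training)
    println(model.toDebugString)  // full if/else description of the learned tree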
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 4d960df357..23956c512c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -118,7 +118,7 @@ class StringIndexerModel private[ml] (
}
val outputColName = map(outputCol)
val metadata = NominalAttribute.defaultAttr
- .withName(outputColName).withValues(labels).toStructField().metadata
+ .withName(outputColName).withValues(labels).toMetadata()
dataset.select(col("*"), indexer(dataset(map(inputCol))).as(outputColName, metadata))
}
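The effect of this one-line change, sketched under the assumption of a DataFrame `df` with a string column "label": the output column now carries full nominal-attribute metadata, so downstream estimators can recover the number of classes.

    import org.apache.spark.ml.feature.StringIndexer

    val indexed = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("indexedLabel")
      .fit(df)
      .transform(df)

    // MetadataUtils.getNumClasses(indexed.schema("indexedLabel")) now returns Some(n).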
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
new file mode 100644
index 0000000000..6f4509f03d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
@@ -0,0 +1,300 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.impl.tree
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.ml.impl.estimator.PredictorParams
+import org.apache.spark.ml.param._
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.impurity.{Gini => OldGini, Entropy => OldEntropy,
+ Impurity => OldImpurity, Variance => OldVariance}
+
+
+/**
+ * :: DeveloperApi ::
+ * Parameters for Decision Tree-based algorithms.
+ *
+ * Note: Marked as private and DeveloperApi since this may be made public in the future.
+ */
+@DeveloperApi
+private[ml] trait DecisionTreeParams extends PredictorParams {
+
+ /**
+ * Maximum depth of the tree.
+ * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+ * (default = 5)
+ * @group param
+ */
+ final val maxDepth: IntParam =
+ new IntParam(this, "maxDepth", "Maximum depth of the tree." +
+ " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.")
+
+ /**
+ * Maximum number of bins used for discretizing continuous features and for choosing how to split
+ * on features at each node. More bins give higher granularity.
+ * Must be >= 2 and >= number of categories in any categorical feature.
+ * (default = 32)
+ * @group param
+ */
+ final val maxBins: IntParam = new IntParam(this, "maxBins", "Max number of bins for" +
+ " discretizing continuous features. Must be >=2 and >= number of categories for any" +
+ " categorical feature.")
+
+ /**
+ * Minimum number of instances each child must have after split.
+ * If a split causes the left or right child to have fewer than minInstancesPerNode,
+ * the split will be discarded as invalid.
+ * Should be >= 1.
+ * (default = 1)
+ * @group param
+ */
+ final val minInstancesPerNode: IntParam = new IntParam(this, "minInstancesPerNode", "Minimum" +
+ " number of instances each child must have after split. If a split causes the left or right" +
+ " child to have fewer than minInstancesPerNode, the split will be discarded as invalid." +
+ " Should be >= 1.")
+
+ /**
+ * Minimum information gain for a split to be considered at a tree node.
+ * (default = 0.0)
+ * @group param
+ */
+ final val minInfoGain: DoubleParam = new DoubleParam(this, "minInfoGain",
+ "Minimum information gain for a split to be considered at a tree node.")
+
+ /**
+ * Maximum memory in MB allocated to histogram aggregation.
+ * (default = 256 MB)
+ * @group expertParam
+ */
+ final val maxMemoryInMB: IntParam = new IntParam(this, "maxMemoryInMB",
+ "Maximum memory in MB allocated to histogram aggregation.")
+
+ /**
+ * If false, the algorithm will pass trees to executors to match instances with nodes.
+ * If true, the algorithm will cache node IDs for each instance.
+ * Caching can speed up training of deeper trees.
+ * (default = false)
+ * @group expertParam
+ */
+ final val cacheNodeIds: BooleanParam = new BooleanParam(this, "cacheNodeIds", "If false, the" +
+ " algorithm will pass trees to executors to match instances with nodes. If true, the" +
+ " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" +
+ " trees.")
+
+ /**
+ * Specifies how often to checkpoint the cached node IDs.
+ * E.g. 10 means that the cache will get checkpointed every 10 iterations.
+ * This is only used if cacheNodeIds is true and if the checkpoint directory is set in
+ * [[org.apache.spark.SparkContext]].
+ * Must be >= 1.
+ * (default = 10)
+ * @group expertParam
+ */
+ final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "Specifies" +
+ " how often to checkpoint the cached node IDs. E.g. 10 means that the cache will get" +
+ " checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the" +
+ " checkpoint directory is set in the SparkContext. Must be >= 1.")
+
+ setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0,
+ maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10)
+
+ /** @group setParam */
+ def setMaxDepth(value: Int): this.type = {
+ require(value >= 0, s"maxDepth parameter must be >= 0. Given bad value: $value")
+ set(maxDepth, value)
+ this
+ }
+
+ /** @group getParam */
+ def getMaxDepth: Int = getOrDefault(maxDepth)
+
+ /** @group setParam */
+ def setMaxBins(value: Int): this.type = {
+ require(value >= 2, s"maxBins parameter must be >= 2. Given bad value: $value")
+ set(maxBins, value)
+ this
+ }
+
+ /** @group getParam */
+ def getMaxBins: Int = getOrDefault(maxBins)
+
+ /** @group setParam */
+ def setMinInstancesPerNode(value: Int): this.type = {
+ require(value >= 1, s"minInstancesPerNode parameter must be >= 1. Given bad value: $value")
+ set(minInstancesPerNode, value)
+ this
+ }
+
+ /** @group getParam */
+ def getMinInstancesPerNode: Int = getOrDefault(minInstancesPerNode)
+
+ /** @group setParam */
+ def setMinInfoGain(value: Double): this.type = {
+ set(minInfoGain, value)
+ this
+ }
+
+ /** @group getParam */
+ def getMinInfoGain: Double = getOrDefault(minInfoGain)
+
+ /** @group expertSetParam */
+ def setMaxMemoryInMB(value: Int): this.type = {
+ require(value > 0, s"maxMemoryInMB parameter must be > 0. Given bad value: $value")
+ set(maxMemoryInMB, value)
+ this
+ }
+
+ /** @group expertGetParam */
+ def getMaxMemoryInMB: Int = getOrDefault(maxMemoryInMB)
+
+ /** @group expertSetParam */
+ def setCacheNodeIds(value: Boolean): this.type = {
+ set(cacheNodeIds, value)
+ this
+ }
+
+ /** @group expertGetParam */
+ def getCacheNodeIds: Boolean = getOrDefault(cacheNodeIds)
+
+ /** @group expertSetParam */
+ def setCheckpointInterval(value: Int): this.type = {
+ require(value >= 1, s"checkpointInterval parameter must be >= 1. Given bad value: $value")
+ set(checkpointInterval, value)
+ this
+ }
+
+ /** @group expertGetParam */
+ def getCheckpointInterval: Int = getOrDefault(checkpointInterval)
+
+ /**
+ * Create a Strategy instance to use with the old API.
+ * NOTE: The caller should set the impurity. The subsamplingRate is set to 1.0 here,
+ * the default for single trees.
+ */
+ private[ml] def getOldStrategy(
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int): OldStrategy = {
+ val strategy = OldStrategy.defaultStategy(OldAlgo.Classification)
+ strategy.checkpointInterval = getCheckpointInterval
+ strategy.maxBins = getMaxBins
+ strategy.maxDepth = getMaxDepth
+ strategy.maxMemoryInMB = getMaxMemoryInMB
+ strategy.minInfoGain = getMinInfoGain
+ strategy.minInstancesPerNode = getMinInstancesPerNode
+ strategy.useNodeIdCache = getCacheNodeIds
+ strategy.numClasses = numClasses
+ strategy.categoricalFeaturesInfo = categoricalFeatures
+ strategy.subsamplingRate = 1.0 // default for individual trees
+ strategy
+ }
+}
+
+/**
+ * (private trait) Parameters for Decision Tree-based classification algorithms.
+ */
+private[ml] trait TreeClassifierParams extends Params {
+
+ /**
+ * Criterion used for information gain calculation (case-insensitive).
+ * Supported: "entropy" and "gini".
+ * (default = gini)
+ * @group param
+ */
+ val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
+ " information gain calculation (case-insensitive). Supported options:" +
+ s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}")
+
+ setDefault(impurity -> "gini")
+
+ /** @group setParam */
+ def setImpurity(value: String): this.type = {
+ val impurityStr = value.toLowerCase
+ require(TreeClassifierParams.supportedImpurities.contains(impurityStr),
+ s"Tree-based classifier was given unrecognized impurity: $value." +
+ s" Supported options: ${TreeClassifierParams.supportedImpurities.mkString(", ")}")
+ set(impurity, impurityStr)
+ this
+ }
+
+ /** @group getParam */
+ def getImpurity: String = getOrDefault(impurity)
+
+ /** Convert new impurity to old impurity. */
+ private[ml] def getOldImpurity: OldImpurity = {
+ getImpurity match {
+ case "entropy" => OldEntropy
+ case "gini" => OldGini
+ case _ =>
+ // Should never happen because of check in setter method.
+ throw new RuntimeException(
+ s"TreeClassifierParams was given unrecognized impurity: $impurity.")
+ }
+ }
+}
+
+private[ml] object TreeClassifierParams {
+ // These options should be lowercase.
+ val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase)
+}
+
+/**
+ * (private trait) Parameters for Decision Tree-based regression algorithms.
+ */
+private[ml] trait TreeRegressorParams extends Params {
+
+ /**
+ * Criterion used for information gain calculation (case-insensitive).
+ * Supported: "variance".
+ * (default = variance)
+ * @group param
+ */
+ val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
+ " information gain calculation (case-insensitive). Supported options:" +
+ s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}")
+
+ setDefault(impurity -> "variance")
+
+ /** @group setParam */
+ def setImpurity(value: String): this.type = {
+ val impurityStr = value.toLowerCase
+ require(TreeRegressorParams.supportedImpurities.contains(impurityStr),
+ s"Tree-based regressor was given unrecognized impurity: $value." +
+ s" Supported options: ${TreeRegressorParams.supportedImpurities.mkString(", ")}")
+ set(impurity, impurityStr)
+ this
+ }
+
+ /** @group getParam */
+ def getImpurity: String = getOrDefault(impurity)
+
+ /** Convert new impurity to old impurity. */
+ protected def getOldImpurity: OldImpurity = {
+ getImpurity match {
+ case "variance" => OldVariance
+ case _ =>
+ // Should never happen because of check in setter method.
+ throw new RuntimeException(
+ s"TreeRegressorParams was given unrecognized impurity: $impurity")
+ }
+ }
+}
+
+private[ml] object TreeRegressorParams {
+ // These options should be lowercase.
+ val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase)
+}
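A small sketch of the parameter surface defined above, exercising the documented defaults and the setter validation (DecisionTreeClassifier mixes in both traits):

    import org.apache.spark.ml.classification.DecisionTreeClassifier

    val dt = new DecisionTreeClassifier()
    assert(dt.getMaxDepth == 5 && dt.getMaxBins == 32)  // defaults from setDefault(...)

    dt.setImpurity("Entropy")           // matching is case-insensitive...
    assert(dt.getImpurity == "entropy") // ...and the value is stored lowercase

    // Setters validate eagerly: dt.setMaxBins(1) would throw IllegalArgumentException.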
diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala
index b45bd1499b..ac75e9de1a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/package.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala
@@ -32,6 +32,18 @@ package org.apache.spark
* @groupname getParam Parameter getters
* @groupprio getParam 6
*
+ * @groupname expertParam (expert-only) Parameters
+ * @groupdesc expertParam A list of advanced, expert-only (hyper-)parameter keys this algorithm can
+ * take. Users can set and get the parameter values through setters and getters,
+ * respectively.
+ * @groupprio expertParam 7
+ *
+ * @groupname expertSetParam (expert-only) Parameter setters
+ * @groupprio expertSetParam 8
+ *
+ * @groupname expertGetParam (expert-only) Parameter getters
+ * @groupprio expertGetParam 9
+ *
* @groupname Ungrouped Members
* @groupprio Ungrouped 0
*/
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 849c60433c..ddc5907e7f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -296,8 +296,9 @@ private[spark] object Params {
paramMap: ParamMap,
parent: E,
child: M): Unit = {
+ val childParams = child.params.map(_.name).toSet
parent.params.foreach { param =>
- if (paramMap.contains(param)) {
+ if (paramMap.contains(param) && childParams.contains(param.name)) {
child.set(child.getParam(param.name), paramMap(param))
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
new file mode 100644
index 0000000000..49a8b77acf
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
+import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.param.{Params, ParamMap}
+import org.apache.spark.ml.tree.{DecisionTreeModel, Node}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm
+ * for regression.
+ * It supports both continuous and categorical features.
+ */
+@AlphaComponent
+final class DecisionTreeRegressor
+ extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel]
+ with DecisionTreeParams
+ with TreeRegressorParams {
+
+ // Override parameter setters from parent trait for Java API compatibility.
+
+ override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
+
+ override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
+
+ override def setMinInstancesPerNode(value: Int): this.type =
+ super.setMinInstancesPerNode(value)
+
+ override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
+
+ override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
+
+ override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
+
+ override def setCheckpointInterval(value: Int): this.type =
+ super.setCheckpointInterval(value)
+
+ override def setImpurity(value: String): this.type = super.setImpurity(value)
+
+ override protected def train(
+ dataset: DataFrame,
+ paramMap: ParamMap): DecisionTreeRegressionModel = {
+ val categoricalFeatures: Map[Int, Int] =
+ MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+ val strategy = getOldStrategy(categoricalFeatures)
+ val oldModel = OldDecisionTree.train(oldDataset, strategy)
+ DecisionTreeRegressionModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+ }
+
+ /** (private[ml]) Create a Strategy instance to use with the old API. */
+ private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = {
+ val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 0)
+ strategy.algo = OldAlgo.Regression
+ strategy.setImpurity(getOldImpurity)
+ strategy
+ }
+}
+
+object DecisionTreeRegressor {
+ /** Accessor for supported impurities */
+ final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for regression.
+ * It supports both continuous and categorical features.
+ * @param rootNode Root of the decision tree
+ */
+@AlphaComponent
+final class DecisionTreeRegressionModel private[ml] (
+ override val parent: DecisionTreeRegressor,
+ override val fittingParamMap: ParamMap,
+ override val rootNode: Node)
+ extends PredictionModel[Vector, DecisionTreeRegressionModel]
+ with DecisionTreeModel with Serializable {
+
+ require(rootNode != null,
+ "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
+
+ override protected def predict(features: Vector): Double = {
+ rootNode.predict(features)
+ }
+
+ override protected def copy(): DecisionTreeRegressionModel = {
+ val m = new DecisionTreeRegressionModel(parent, fittingParamMap, rootNode)
+ Params.inheritValues(this.extractParamMap(), this, m)
+ m
+ }
+
+ override def toString: String = {
+ s"DecisionTreeRegressionModel of depth $depth with $numNodes nodes"
+ }
+
+ /** Convert to a model in the old API */
+ private[ml] def toOld: OldDecisionTreeModel = {
+ new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression)
+ }
+}
+
+private[ml] object DecisionTreeRegressionModel {
+
+ /** (private[ml]) Convert a model from the old API */
+ def fromOld(
+ oldModel: OldDecisionTreeModel,
+ parent: DecisionTreeRegressor,
+ fittingParamMap: ParamMap,
+ categoricalFeatures: Map[Int, Int]): DecisionTreeRegressionModel = {
+ require(oldModel.algo == OldAlgo.Regression,
+ s"Cannot convert non-regression DecisionTreeModel (old API) to" +
+ s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}")
+ val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
+ new DecisionTreeRegressionModel(parent, fittingParamMap, rootNode)
+ }
+}
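The regression counterpart, sketched under the same assumptions (`training` and `test` are DataFrames with "label" and "features" columns):

    import org.apache.spark.ml.regression.DecisionTreeRegressor

    val dtr = new DecisionTreeRegressor()
      .setMaxDepth(3)
      .setImpurity("variance")  // the only impurity supported for regression

    val model = dtr.fit(training)
    val predictions = model.transform(test)  // appends a "prediction" column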
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
new file mode 100644
index 0000000000..d6e2203d9f
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree
+
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats,
+ Node => OldNode, Predict => OldPredict}
+
+
+/**
+ * Decision tree node interface.
+ */
+sealed abstract class Node extends Serializable {
+
+ // TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree
+ // code into the new API and deprecate the old API.
+
+ /** Prediction this node makes (or would make, if it is an internal node) */
+ def prediction: Double
+
+ /** Impurity measure at this node (for training data) */
+ def impurity: Double
+
+ /** Recursive prediction helper method */
+ private[ml] def predict(features: Vector): Double = prediction
+
+ /**
+ * Get the number of nodes in tree below this node, including leaf nodes.
+ * E.g., if this is a leaf, returns 0. If both children are leaves, returns 2.
+ */
+ private[tree] def numDescendants: Int
+
+ /**
+ * Recursive print function.
+ * @param indentFactor The number of spaces to add to each level of indentation.
+ */
+ private[tree] def subtreeToString(indentFactor: Int = 0): String
+
+ /**
+ * Get depth of tree from this node.
+ * E.g.: Depth 0 means this is a leaf node. Depth 1 means 1 internal and 2 leaf nodes.
+ */
+ private[tree] def subtreeDepth: Int
+
+ /**
+ * Create a copy of this node in the old Node format, recursively creating child nodes as needed.
+ * @param id Node ID using old format IDs
+ */
+ private[ml] def toOld(id: Int): OldNode
+}
+
+private[ml] object Node {
+
+ /**
+ * Create a new Node from the old Node format, recursively creating child nodes as needed.
+ */
+ def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node = {
+ if (oldNode.isLeaf) {
+ // TODO: Once the implementation has been moved to this API, then include sufficient
+ // statistics here.
+ new LeafNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity)
+ } else {
+ val gain = if (oldNode.stats.nonEmpty) {
+ oldNode.stats.get.gain
+ } else {
+ 0.0
+ }
+ new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity,
+ gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures),
+ rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures),
+ split = Split.fromOld(oldNode.split.get, categoricalFeatures))
+ }
+ }
+}
+
+/**
+ * Decision tree leaf node.
+ * @param prediction Prediction this node makes
+ * @param impurity Impurity measure at this node (for training data)
+ */
+final class LeafNode private[ml] (
+ override val prediction: Double,
+ override val impurity: Double) extends Node {
+
+ override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)"
+
+ override private[ml] def predict(features: Vector): Double = prediction
+
+ override private[tree] def numDescendants: Int = 0
+
+ override private[tree] def subtreeToString(indentFactor: Int = 0): String = {
+ val prefix: String = " " * indentFactor
+ prefix + s"Predict: $prediction\n"
+ }
+
+ override private[tree] def subtreeDepth: Int = 0
+
+ override private[ml] def toOld(id: Int): OldNode = {
+ // NOTE: We do NOT store 'prob' in the new API currently.
+ new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = true,
+ None, None, None, None)
+ }
+}
+
+/**
+ * Internal Decision Tree node.
+ * @param prediction Prediction this node would make if it were a leaf node
+ * @param impurity Impurity measure at this node (for training data)
+ * @param gain Information gain value.
+ * Values < 0 indicate missing values; this quirk will be removed with future updates.
+ * @param leftChild Left-hand child node
+ * @param rightChild Right-hand child node
+ * @param split Information about the test used to split to the left or right child.
+ */
+final class InternalNode private[ml] (
+ override val prediction: Double,
+ override val impurity: Double,
+ val gain: Double,
+ val leftChild: Node,
+ val rightChild: Node,
+ val split: Split) extends Node {
+
+ override def toString: String = {
+ s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)"
+ }
+
+ override private[ml] def predict(features: Vector): Double = {
+ if (split.shouldGoLeft(features)) {
+ leftChild.predict(features)
+ } else {
+ rightChild.predict(features)
+ }
+ }
+
+ override private[tree] def numDescendants: Int = {
+ 2 + leftChild.numDescendants + rightChild.numDescendants
+ }
+
+ override private[tree] def subtreeToString(indentFactor: Int = 0): String = {
+ val prefix: String = " " * indentFactor
+ prefix + s"If (${InternalNode.splitToString(split, left=true)})\n" +
+ leftChild.subtreeToString(indentFactor + 1) +
+ prefix + s"Else (${InternalNode.splitToString(split, left=false)})\n" +
+ rightChild.subtreeToString(indentFactor + 1)
+ }
+
+ override private[tree] def subtreeDepth: Int = {
+ 1 + math.max(leftChild.subtreeDepth, rightChild.subtreeDepth)
+ }
+
+ override private[ml] def toOld(id: Int): OldNode = {
+ assert(id.toLong * 2 < Int.MaxValue, "Decision Tree could not be converted from new to old API"
+ + " since the old API does not support deep trees.")
+ // NOTE: We do NOT store 'prob' in the new API currently.
+ new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = false,
+ Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))),
+ Some(rightChild.toOld(OldNode.rightChildIndex(id))),
+ Some(new OldInformationGainStats(gain, impurity, leftChild.impurity, rightChild.impurity,
+ new OldPredict(leftChild.prediction, prob = 0.0),
+ new OldPredict(rightChild.prediction, prob = 0.0))))
+ }
+}
+
+private object InternalNode {
+
+ /**
+ * Helper method for [[Node.subtreeToString()]].
+ * @param split Split to print
+ * @param left Indicates whether this is the part of the split going to the left,
+ * or that going to the right.
+ */
+ private def splitToString(split: Split, left: Boolean): String = {
+ val featureStr = s"feature ${split.featureIndex}"
+ split match {
+ case contSplit: ContinuousSplit =>
+ if (left) {
+ s"$featureStr <= ${contSplit.threshold}"
+ } else {
+ s"$featureStr > ${contSplit.threshold}"
+ }
+ case catSplit: CategoricalSplit =>
+ val categoriesStr = catSplit.getLeftCategories.mkString("{", ",", "}")
+ if (left) {
+ s"$featureStr in $categoriesStr"
+ } else {
+ s"$featureStr not in $categoriesStr"
+ }
+ }
+ }
+}
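Because Node is sealed with exactly two concrete subclasses, client code can traverse a tree by pattern matching; a minimal sketch:

    import org.apache.spark.ml.tree.{InternalNode, LeafNode, Node}

    // Count the leaves of a tree by recursing over the sealed hierarchy.
    def countLeaves(node: Node): Int = node match {
      case _: LeafNode     => 1
      case n: InternalNode => countLeaves(n.leftChild) + countLeaves(n.rightChild)
    }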
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
new file mode 100644
index 0000000000..cb940f6299
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree
+
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType}
+import org.apache.spark.mllib.tree.model.{Split => OldSplit}
+
+
+/**
+ * Interface for a "Split," which specifies a test made at a decision tree node
+ * to choose the left or right path.
+ */
+sealed trait Split extends Serializable {
+
+ /** Index of feature which this split tests */
+ def featureIndex: Int
+
+ /** Return true (split to left) or false (split to right) */
+ private[ml] def shouldGoLeft(features: Vector): Boolean
+
+ /** Convert to old Split format */
+ private[tree] def toOld: OldSplit
+}
+
+private[ml] object Split {
+
+ def fromOld(oldSplit: OldSplit, categoricalFeatures: Map[Int, Int]): Split = {
+ oldSplit.featureType match {
+ case OldFeatureType.Categorical =>
+ new CategoricalSplit(featureIndex = oldSplit.feature,
+ leftCategories = oldSplit.categories.toArray, categoricalFeatures(oldSplit.feature))
+ case OldFeatureType.Continuous =>
+ new ContinuousSplit(featureIndex = oldSplit.feature, threshold = oldSplit.threshold)
+ }
+ }
+}
+
+/**
+ * Split which tests a categorical feature.
+ * @param featureIndex Index of the feature to test
+ * @param leftCategories If the feature value is in this set of categories, then the split goes
+ * left. Otherwise, it goes right.
+ * @param numCategories Number of categories for this feature.
+ */
+final class CategoricalSplit(
+ override val featureIndex: Int,
+ leftCategories: Array[Double],
+ private val numCategories: Int)
+ extends Split {
+
+ require(leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" +
+ s" (should be in range [0, $numCategories)): ${leftCategories.mkString(",")}")
+
+ /**
+ * If true, then "categories" is the set of categories for splitting to the left, and vice versa.
+ */
+ private val isLeft: Boolean = leftCategories.length <= numCategories / 2
+
+ /** Set of categories determining the splitting rule, along with [[isLeft]]. */
+ private val categories: Set[Double] = {
+ if (isLeft) {
+ leftCategories.toSet
+ } else {
+ setComplement(leftCategories.toSet)
+ }
+ }
+
+ override private[ml] def shouldGoLeft(features: Vector): Boolean = {
+ if (isLeft) {
+ categories.contains(features(featureIndex))
+ } else {
+ !categories.contains(features(featureIndex))
+ }
+ }
+
+ override def equals(o: Any): Boolean = {
+ o match {
+ case other: CategoricalSplit => featureIndex == other.featureIndex &&
+ isLeft == other.isLeft && categories == other.categories
+ case _ => false
+ }
+ }
+
+ override private[tree] def toOld: OldSplit = {
+ val oldCats = if (isLeft) {
+ categories
+ } else {
+ setComplement(categories)
+ }
+ OldSplit(featureIndex, threshold = 0.0, OldFeatureType.Categorical, oldCats.toList)
+ }
+
+ /** Get sorted categories which split to the left */
+ def getLeftCategories: Array[Double] = {
+ val cats = if (isLeft) categories else setComplement(categories)
+ cats.toArray.sorted
+ }
+
+ /** Get sorted categories which split to the right */
+ def getRightCategories: Array[Double] = {
+ val cats = if (isLeft) setComplement(categories) else categories
+ cats.toArray.sorted
+ }
+
+ /** [0, numCategories) \ cats */
+ private def setComplement(cats: Set[Double]): Set[Double] = {
+ Range(0, numCategories).map(_.toDouble).filter(cat => !cats.contains(cat)).toSet
+ }
+}
+
+/**
+ * Split which tests a continuous feature.
+ * @param featureIndex Index of the feature to test
+ * @param threshold If the feature value is <= this threshold, then the split goes left.
+ * Otherwise, it goes right.
+ */
+final class ContinuousSplit(override val featureIndex: Int, val threshold: Double) extends Split {
+
+ override private[ml] def shouldGoLeft(features: Vector): Boolean = {
+ features(featureIndex) <= threshold
+ }
+
+ override def equals(o: Any): Boolean = {
+ o match {
+ case other: ContinuousSplit =>
+ featureIndex == other.featureIndex && threshold == other.threshold
+ case _ =>
+ false
+ }
+ }
+
+ override private[tree] def toOld: OldSplit = {
+ OldSplit(featureIndex, threshold, OldFeatureType.Continuous, List.empty[Double])
+ }
+}
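A sketch of CategoricalSplit's space-saving representation: whichever side of the split has fewer categories is what gets stored, but the public accessors recover both sides.

    import org.apache.spark.ml.tree.CategoricalSplit

    // 5 categories; the left set {0, 1, 3} is larger than its complement {2, 4},
    // so the complement is stored internally (isLeft = false).
    val split = new CategoricalSplit(featureIndex = 1,
      leftCategories = Array(0.0, 1.0, 3.0), numCategories = 5)

    assert(split.getLeftCategories.sameElements(Array(0.0, 1.0, 3.0)))
    assert(split.getRightCategories.sameElements(Array(2.0, 4.0)))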
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
new file mode 100644
index 0000000000..8e3bc3849d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree
+
+import org.apache.spark.annotation.AlphaComponent
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Abstraction for Decision Tree models.
+ *
+ * TODO: Add support for predicting probabilities and raw predictions
+ */
+@AlphaComponent
+trait DecisionTreeModel {
+
+ /** Root of the decision tree */
+ def rootNode: Node
+
+ /** Number of nodes in tree, including leaf nodes. */
+ def numNodes: Int = {
+ 1 + rootNode.numDescendants
+ }
+
+ /**
+ * Depth of the tree.
+ * E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes.
+ */
+ lazy val depth: Int = {
+ rootNode.subtreeDepth
+ }
+
+ /** Summary of the model */
+ override def toString: String = {
+ // Implementing classes should generally override this method to be more descriptive.
+ s"DecisionTreeModel of depth $depth with $numNodes nodes"
+ }
+
+ /** Full description of model */
+ def toDebugString: String = {
+ val header = toString + "\n"
+ header + rootNode.subtreeToString(2)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
new file mode 100644
index 0000000000..c84c8b4eb7
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import scala.collection.immutable.HashMap
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, BinaryAttribute, NominalAttribute,
+ NumericAttribute}
+import org.apache.spark.sql.types.StructField
+
+
+/**
+ * :: Experimental ::
+ *
+ * Helper utilities for tree-based algorithms
+ */
+@Experimental
+object MetadataUtils {
+
+ /**
+ * Examine a schema to identify the number of classes in a label column.
+ * Returns None if the number of labels is not specified, or if the label column is continuous.
+ */
+ def getNumClasses(labelSchema: StructField): Option[Int] = {
+ Attribute.fromStructField(labelSchema) match {
+ case numAttr: NumericAttribute => None
+ case binAttr: BinaryAttribute => Some(2)
+ case nomAttr: NominalAttribute => nomAttr.getNumValues
+ }
+ }
+
+ /**
+ * Examine a schema to identify categorical (Binary and Nominal) features.
+ *
+ * @param featuresSchema Schema of the features column.
+ * If a feature does not have metadata, it is assumed to be continuous.
+ * If a feature is Nominal, then it must have the number of values
+ * specified.
+ * @return Map: feature index --> number of categories.
+ * The map's set of keys will be the set of categorical feature indices.
+ */
+ def getCategoricalFeatures(featuresSchema: StructField): Map[Int, Int] = {
+ val metadata = AttributeGroup.fromStructField(featuresSchema)
+ if (metadata.attributes.isEmpty) {
+ HashMap.empty[Int, Int]
+ } else {
+ metadata.attributes.get.zipWithIndex.flatMap { case (attr, idx) =>
+ if (attr == null) {
+ Iterator()
+ } else {
+ attr match {
+ case numAttr: NumericAttribute => Iterator()
+ case binAttr: BinaryAttribute => Iterator(idx -> 2)
+ case nomAttr: NominalAttribute =>
+ nomAttr.getNumValues match {
+ case Some(numValues: Int) => Iterator(idx -> numValues)
+ case None => throw new IllegalArgumentException(s"Feature $idx is marked as" +
+ " Nominal (categorical), but it does not have the number of values specified.")
+ }
+ }
+ }
+ }.toMap
+ }
+ }
+
+}
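A sketch of how categorical features are declared and recovered through column metadata (uses AttributeGroup's public constructor; feature 1 is nominal with 4 values, the others continuous):

    import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute, NumericAttribute}
    import org.apache.spark.ml.util.MetadataUtils

    val attrs: Array[Attribute] = Array(
      NumericAttribute.defaultAttr,
      NominalAttribute.defaultAttr.withNumValues(4),
      NumericAttribute.defaultAttr)
    val group = new AttributeGroup("features", attrs)

    // Only the nominal feature appears, keyed by its index in the vector.
    assert(MetadataUtils.getCategoricalFeatures(group.toStructField()) == Map(1 -> 4))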
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index b9d0c56dd1..dfe3a0b691 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -1147,7 +1147,10 @@ object DecisionTree extends Serializable with Logging {
}
}
- assert(splits.length > 0)
+ // TODO: Do not fail; just ignore the useless feature.
+ assert(splits.length > 0,
+ s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
+ " Please remove this feature and then try again.")
// set number of splits accordingly
metadata.setNumSplits(featureIndex, splits.length)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index c02c79f094..0e31c7ed58 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -81,11 +81,11 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
/**
* Method to validate a gradient boosting model
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * @param validationInput Validation dataset:
- RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- Should be different from and follow the same distribution as input.
- e.g., these two datasets could be created from an original dataset
- by using [[org.apache.spark.rdd.RDD.randomSplit()]]
+ * @param validationInput Validation dataset.
+ * This dataset should be different from the training dataset,
+ * but it should follow the same distribution.
+ * E.g., these two datasets could be created from an original dataset
+ * by using [[org.apache.spark.rdd.RDD.randomSplit()]]
* @return a gradient boosted trees model that can be used for prediction
*/
def runWithValidation(
@@ -194,8 +194,6 @@ object GradientBoostedTrees extends Logging {
val firstTreeWeight = 1.0
baseLearners(0) = firstTreeModel
baseLearnerWeights(0) = firstTreeWeight
- val startingModel = new GradientBoostedTreesModel(
- Regression, Array(firstTreeModel), baseLearnerWeights.slice(0, 1))
var predError: RDD[(Double, Double)] = GradientBoostedTreesModel.
computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
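The documented train/validation pattern, sketched (assumes `data: RDD[LabeledPoint]`; the split ratio is illustrative):

    import org.apache.spark.mllib.tree.GradientBoostedTrees
    import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy}

    val Array(train, validation) = data.randomSplit(Array(0.8, 0.2))
    val boostingStrategy = BoostingStrategy.defaultParams(Algo.Regression)

    // Stops adding trees once the error on the validation set stops improving.
    val model = new GradientBoostedTrees(boostingStrategy).runWithValidation(train, validation)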
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index db01f2e229..055e60c7d9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -249,7 +249,7 @@ private class RandomForest (
nodeIdCache.get.deleteAllCheckpoints()
} catch {
case e:IOException =>
- logWarning(s"delete all chackpoints failed. Error reason: ${e.getMessage}")
+ logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}")
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index 664c8df019..2d6b01524f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -89,14 +89,14 @@ object BoostingStrategy {
* @return Configuration for boosting algorithm
*/
def defaultParams(algo: Algo): BoostingStrategy = {
- val treeStragtegy = Strategy.defaultStategy(algo)
- treeStragtegy.maxDepth = 3
+ val treeStrategy = Strategy.defaultStategy(algo)
+ treeStrategy.maxDepth = 3
algo match {
case Algo.Classification =>
- treeStragtegy.numClasses = 2
- new BoostingStrategy(treeStragtegy, LogLoss)
+ treeStrategy.numClasses = 2
+ new BoostingStrategy(treeStrategy, LogLoss)
case Algo.Regression =>
- new BoostingStrategy(treeStragtegy, SquaredError)
+ new BoostingStrategy(treeStrategy, SquaredError)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by boosting.")
}
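The resulting defaults, sketched:

    import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy}
    import org.apache.spark.mllib.tree.loss.LogLoss

    val bs = BoostingStrategy.defaultParams(Algo.Classification)
    assert(bs.treeStrategy.maxDepth == 3)   // shallow trees by default for boosting
    assert(bs.treeStrategy.numClasses == 2) // binary classification
    assert(bs.loss == LogLoss)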
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
index 6f570b4e09..2bdef73c4a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.loss
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.TreeEnsembleModel
-import org.apache.spark.rdd.RDD
+
/**
* :: DeveloperApi ::
@@ -45,9 +45,8 @@ object AbsoluteError extends Loss {
if (label - prediction < 0) 1.0 else -1.0
}
- override def computeError(prediction: Double, label: Double): Double = {
+ override private[mllib] def computeError(prediction: Double, label: Double): Double = {
val err = label - prediction
math.abs(err)
}
-
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
index 24ee9f3d51..778c24526d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.TreeEnsembleModel
import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.rdd.RDD
+
/**
* :: DeveloperApi ::
@@ -47,10 +47,9 @@ object LogLoss extends Loss {
- 4.0 * label / (1.0 + math.exp(2.0 * label * prediction))
}
- override def computeError(prediction: Double, label: Double): Double = {
+ override private[mllib] def computeError(prediction: Double, label: Double): Double = {
val margin = 2.0 * label * prediction
// The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
2.0 * MLUtils.log1pExp(-margin)
}
-
}
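
The "more numerically stable" comment above relies on MLUtils.log1pExp, which computes log(1 + exp(x)) without overflowing exp(x) for large x. A stand-alone sketch of the identity involved (log(1 + e^x) = x + log(1 + e^-x) for x > 0):

    def log1pExp(x: Double): Double =
      if (x > 0) x + math.log1p(math.exp(-x)) else math.log1p(math.exp(x))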
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
index d3b82b752f..64ffccbce0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
@@ -22,6 +22,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.TreeEnsembleModel
import org.apache.spark.rdd.RDD
+
/**
* :: DeveloperApi ::
* Trait for adding "pluggable" loss functions for the gradient boosting algorithm.
@@ -57,6 +58,5 @@ trait Loss extends Serializable {
* @param label True label.
* @return Measure of model error on datapoint.
*/
- def computeError(prediction: Double, label: Double): Double
-
+ private[mllib] def computeError(prediction: Double, label: Double): Double
}
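
With the per-datapoint computeError(prediction, label) now private[mllib], code outside MLlib evaluates a model's loss through the RDD-based computeError(model, data) overload that the trait also declares (not shown in this hunk). A usage sketch, assuming a trained model and an RDD[LabeledPoint] named data:

    val meanAbsoluteError = AbsoluteError.computeError(model, data)
    val meanSquaredError = SquaredError.computeError(model, data)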
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
index 58857ae15e..a5582d3ef3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.loss
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.TreeEnsembleModel
-import org.apache.spark.rdd.RDD
+
/**
* :: DeveloperApi ::
@@ -45,9 +45,8 @@ object SquaredError extends Loss {
2.0 * (prediction - label)
}
- override def computeError(prediction: Double, label: Double): Double = {
+ override private[mllib] def computeError(prediction: Double, label: Double): Double = {
val err = prediction - label
err * err
}
-
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index c9bafd60fb..331af42853 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -113,11 +113,13 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
DecisionTreeModel.SaveLoadV1_0.save(sc, path, this)
}
- override protected def formatVersion: String = "1.0"
+ override protected def formatVersion: String = DecisionTreeModel.formatVersion
}
object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
+ private[spark] def formatVersion: String = "1.0"
+
private[tree] object SaveLoadV1_0 {
def thisFormatVersion: String = "1.0"
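
formatVersion only comes into play at save/load time, when a version string is written into and later validated against the persisted metadata. A round-trip sketch using the existing public API (sc, path, and model are assumptions):

    model.save(sc, path)
    val sameModel = DecisionTreeModel.load(sc, path)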
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 4f72bb8014..708ba04b56 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -175,7 +175,7 @@ class Node (
}
}
-private[tree] object Node {
+private[spark] object Node {
/**
* Return a node with the given node id (but nothing else set).
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index fef3d2acb2..8341219bfa 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -38,6 +38,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.util.Utils
+
/**
* :: Experimental ::
* Represents a random forest model.
@@ -47,7 +48,7 @@ import org.apache.spark.util.Utils
*/
@Experimental
class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel])
- extends TreeEnsembleModel(algo, trees, Array.fill(trees.size)(1.0),
+ extends TreeEnsembleModel(algo, trees, Array.fill(trees.length)(1.0),
combiningStrategy = if (algo == Classification) Vote else Average)
with Saveable {
@@ -58,11 +59,13 @@ class RandomForestModel(override val algo: Algo, override val trees: Array[Decis
RandomForestModel.SaveLoadV1_0.thisClassName)
}
- override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
+ override protected def formatVersion: String = RandomForestModel.formatVersion
}
object RandomForestModel extends Loader[RandomForestModel] {
+ private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
+
override def load(sc: SparkContext, path: String): RandomForestModel = {
val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path)
val classNameV1_0 = SaveLoadV1_0.thisClassName
@@ -102,15 +105,13 @@ class GradientBoostedTreesModel(
extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum)
with Saveable {
- require(trees.size == treeWeights.size)
+ require(trees.length == treeWeights.length)
override def save(sc: SparkContext, path: String): Unit = {
TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this,
GradientBoostedTreesModel.SaveLoadV1_0.thisClassName)
}
- override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
-
/**
* Method to compute error or loss for every iteration of gradient boosting.
* @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
@@ -138,7 +139,7 @@ class GradientBoostedTreesModel(
evaluationArray(0) = predictionAndError.values.mean()
val broadcastTrees = sc.broadcast(trees)
- (1 until numIterations).map { nTree =>
+ (1 until numIterations).foreach { nTree =>
predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
val currentTree = broadcastTrees.value(nTree)
val currentTreeWeight = localTreeWeights(nTree)
@@ -155,6 +156,7 @@ class GradientBoostedTreesModel(
evaluationArray
}
+ override protected def formatVersion: String = GradientBoostedTreesModel.formatVersion
}
object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
@@ -200,17 +202,17 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
loss: Loss): RDD[(Double, Double)] = {
val newPredError = data.zip(predictionAndError).mapPartitions { iter =>
- iter.map {
- case (lp, (pred, error)) => {
- val newPred = pred + tree.predict(lp.features) * treeWeight
- val newError = loss.computeError(newPred, lp.label)
- (newPred, newError)
- }
+ iter.map { case (lp, (pred, error)) =>
+ val newPred = pred + tree.predict(lp.features) * treeWeight
+ val newError = loss.computeError(newPred, lp.label)
+ (newPred, newError)
}
}
newPredError
}
+ private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
+
override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = {
val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path)
val classNameV1_0 = SaveLoadV1_0.thisClassName
@@ -340,12 +342,12 @@ private[tree] sealed class TreeEnsembleModel(
}
/**
- * Get number of trees in forest.
+ * Get number of trees in ensemble.
*/
- def numTrees: Int = trees.size
+ def numTrees: Int = trees.length
/**
- * Get total number of nodes, summed over all trees in the forest.
+ * Get total number of nodes, summed over all trees in the ensemble.
*/
def totalNumNodes: Int = trees.map(_.numNodes).sum
}
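
The per-iteration loop changed above (map replaced by foreach, since only the side effect on evaluationArray is needed) backs evaluateEachIteration() in the public API. A sketch of pairing it with a held-out set (variable names are assumptions):

    val model = GradientBoostedTrees.train(training, boostingStrategy)
    val errors: Array[Double] = model.evaluateEachIteration(validation, SquaredError)
    // Iteration count that minimizes held-out error:
    val bestNumIterations = errors.indexOf(errors.min) + 1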
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
new file mode 100644
index 0000000000..43b8787f9d
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification;
+
+import java.io.File;
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.mllib.classification.LogisticRegressionSuite;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.util.Utils;
+
+
+public class JavaDecisionTreeClassifierSuite implements Serializable {
+
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaDecisionTreeClassifierSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void runDT() {
+ int nPoints = 20;
+ double A = 2.0;
+ double B = -1.5;
+
+ JavaRDD<LabeledPoint> data = sc.parallelize(
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
+ DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+
+ // This tests setters. Training with various options is tested in Scala.
+ DecisionTreeClassifier dt = new DecisionTreeClassifier()
+ .setMaxDepth(2)
+ .setMaxBins(10)
+ .setMinInstancesPerNode(5)
+ .setMinInfoGain(0.0)
+ .setMaxMemoryInMB(256)
+ .setCacheNodeIds(false)
+ .setCheckpointInterval(10)
+ .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
+ for (int i = 0; i < DecisionTreeClassifier.supportedImpurities().length; ++i) {
+ dt.setImpurity(DecisionTreeClassifier.supportedImpurities()[i]);
+ }
+ DecisionTreeClassificationModel model = dt.fit(dataFrame);
+
+ model.transform(dataFrame);
+ model.numNodes();
+ model.depth();
+ model.toDebugString();
+
+ /*
+ // TODO: Add test once save/load are implemented.
+ File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
+ String path = tempDir.toURI().toString();
+ try {
+      model.save(sc.sc(), path);
+      DecisionTreeClassificationModel sameModel =
+        DecisionTreeClassificationModel.load(sc.sc(), path);
+      TreeTests.checkEqual(model, sameModel);
+ } finally {
+ Utils.deleteRecursively(tempDir);
+ }
+ */
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
new file mode 100644
index 0000000000..a3a339004f
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression;
+
+import java.io.File;
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.mllib.classification.LogisticRegressionSuite;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.util.Utils;
+
+
+public class JavaDecisionTreeRegressorSuite implements Serializable {
+
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaDecisionTreeRegressorSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void runDT() {
+ int nPoints = 20;
+ double A = 2.0;
+ double B = -1.5;
+
+ JavaRDD<LabeledPoint> data = sc.parallelize(
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
+ DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+
+ // This tests setters. Training with various options is tested in Scala.
+ DecisionTreeRegressor dt = new DecisionTreeRegressor()
+ .setMaxDepth(2)
+ .setMaxBins(10)
+ .setMinInstancesPerNode(5)
+ .setMinInfoGain(0.0)
+ .setMaxMemoryInMB(256)
+ .setCacheNodeIds(false)
+ .setCheckpointInterval(10)
+ .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
+ for (int i = 0; i < DecisionTreeRegressor.supportedImpurities().length; ++i) {
+ dt.setImpurity(DecisionTreeRegressor.supportedImpurities()[i]);
+ }
+ DecisionTreeRegressionModel model = dt.fit(dataFrame);
+
+ model.transform(dataFrame);
+ model.numNodes();
+ model.depth();
+ model.toDebugString();
+
+ /*
+ // TODO: Add test once save/load are implemented.
+ File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
+ String path = tempDir.toURI().toString();
+ try {
+      model.save(sc.sc(), path);
+      DecisionTreeRegressionModel sameModel = DecisionTreeRegressionModel.load(sc.sc(), path);
+      TreeTests.checkEqual(model, sameModel);
+ } finally {
+ Utils.deleteRecursively(tempDir);
+ }
+ */
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
index 0dcfe5a200..17ddd335de 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
@@ -44,7 +44,7 @@ class AttributeGroupSuite extends FunSuite {
group("abc")
}
assert(group === AttributeGroup.fromMetadata(group.toMetadataImpl, group.name))
- assert(group === AttributeGroup.fromStructField(group.toStructField))
+ assert(group === AttributeGroup.fromStructField(group.toStructField()))
}
test("attribute group without attributes") {
@@ -54,7 +54,7 @@ class AttributeGroupSuite extends FunSuite {
assert(group0.size === 10)
assert(group0.attributes.isEmpty)
assert(group0 === AttributeGroup.fromMetadata(group0.toMetadataImpl, group0.name))
- assert(group0 === AttributeGroup.fromStructField(group0.toStructField))
+ assert(group0 === AttributeGroup.fromStructField(group0.toStructField()))
val group1 = new AttributeGroup("item")
assert(group1.name === "item")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
index 6ec35b0365..3e1a7196e3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
@@ -36,9 +36,9 @@ class AttributeSuite extends FunSuite {
assert(attr.max.isEmpty)
assert(attr.std.isEmpty)
assert(attr.sparsity.isEmpty)
- assert(attr.toMetadata() === metadata)
- assert(attr.toMetadata(withType = false) === metadata)
- assert(attr.toMetadata(withType = true) === metadataWithType)
+ assert(attr.toMetadataImpl() === metadata)
+ assert(attr.toMetadataImpl(withType = false) === metadata)
+ assert(attr.toMetadataImpl(withType = true) === metadataWithType)
assert(attr === Attribute.fromMetadata(metadata))
assert(attr === Attribute.fromMetadata(metadataWithType))
intercept[NoSuchElementException] {
@@ -59,9 +59,9 @@ class AttributeSuite extends FunSuite {
assert(!attr.isNominal)
assert(attr.name === Some(name))
assert(attr.index === Some(index))
- assert(attr.toMetadata() === metadata)
- assert(attr.toMetadata(withType = false) === metadata)
- assert(attr.toMetadata(withType = true) === metadataWithType)
+ assert(attr.toMetadataImpl() === metadata)
+ assert(attr.toMetadataImpl(withType = false) === metadata)
+ assert(attr.toMetadataImpl(withType = true) === metadataWithType)
assert(attr === Attribute.fromMetadata(metadata))
assert(attr === Attribute.fromMetadata(metadataWithType))
val field = attr.toStructField()
@@ -81,7 +81,7 @@ class AttributeSuite extends FunSuite {
assert(attr2.max === Some(1.0))
assert(attr2.std === Some(0.5))
assert(attr2.sparsity === Some(0.3))
- assert(attr2 === Attribute.fromMetadata(attr2.toMetadata()))
+ assert(attr2 === Attribute.fromMetadata(attr2.toMetadataImpl()))
}
test("bad numeric attributes") {
@@ -105,9 +105,9 @@ class AttributeSuite extends FunSuite {
assert(attr.values.isEmpty)
assert(attr.numValues.isEmpty)
assert(attr.isOrdinal.isEmpty)
- assert(attr.toMetadata() === metadata)
- assert(attr.toMetadata(withType = true) === metadata)
- assert(attr.toMetadata(withType = false) === metadataWithoutType)
+ assert(attr.toMetadataImpl() === metadata)
+ assert(attr.toMetadataImpl(withType = true) === metadata)
+ assert(attr.toMetadataImpl(withType = false) === metadataWithoutType)
assert(attr === Attribute.fromMetadata(metadata))
assert(attr === NominalAttribute.fromMetadata(metadataWithoutType))
intercept[NoSuchElementException] {
@@ -135,9 +135,9 @@ class AttributeSuite extends FunSuite {
assert(attr.values === Some(values))
assert(attr.indexOf("medium") === 1)
assert(attr.getValue(1) === "medium")
- assert(attr.toMetadata() === metadata)
- assert(attr.toMetadata(withType = true) === metadata)
- assert(attr.toMetadata(withType = false) === metadataWithoutType)
+ assert(attr.toMetadataImpl() === metadata)
+ assert(attr.toMetadataImpl(withType = true) === metadata)
+ assert(attr.toMetadataImpl(withType = false) === metadataWithoutType)
assert(attr === Attribute.fromMetadata(metadata))
assert(attr === NominalAttribute.fromMetadata(metadataWithoutType))
assert(attr.withoutIndex === Attribute.fromStructField(attr.toStructField()))
@@ -147,8 +147,8 @@ class AttributeSuite extends FunSuite {
assert(attr2.index.isEmpty)
assert(attr2.values.get === Array("small", "medium", "large", "x-large"))
assert(attr2.indexOf("x-large") === 3)
- assert(attr2 === Attribute.fromMetadata(attr2.toMetadata()))
- assert(attr2 === NominalAttribute.fromMetadata(attr2.toMetadata(withType = false)))
+ assert(attr2 === Attribute.fromMetadata(attr2.toMetadataImpl()))
+ assert(attr2 === NominalAttribute.fromMetadata(attr2.toMetadataImpl(withType = false)))
}
test("bad nominal attributes") {
@@ -168,9 +168,9 @@ class AttributeSuite extends FunSuite {
assert(attr.name.isEmpty)
assert(attr.index.isEmpty)
assert(attr.values.isEmpty)
- assert(attr.toMetadata() === metadata)
- assert(attr.toMetadata(withType = true) === metadata)
- assert(attr.toMetadata(withType = false) === metadataWithoutType)
+ assert(attr.toMetadataImpl() === metadata)
+ assert(attr.toMetadataImpl(withType = true) === metadata)
+ assert(attr.toMetadataImpl(withType = false) === metadataWithoutType)
assert(attr === Attribute.fromMetadata(metadata))
assert(attr === BinaryAttribute.fromMetadata(metadataWithoutType))
intercept[NoSuchElementException] {
@@ -196,9 +196,9 @@ class AttributeSuite extends FunSuite {
assert(attr.name === Some(name))
assert(attr.index === Some(index))
assert(attr.values.get === values)
- assert(attr.toMetadata() === metadata)
- assert(attr.toMetadata(withType = true) === metadata)
- assert(attr.toMetadata(withType = false) === metadataWithoutType)
+ assert(attr.toMetadataImpl() === metadata)
+ assert(attr.toMetadataImpl(withType = true) === metadata)
+ assert(attr.toMetadataImpl(withType = false) === metadataWithoutType)
assert(attr === Attribute.fromMetadata(metadata))
assert(attr === BinaryAttribute.fromMetadata(metadataWithoutType))
assert(attr.withoutIndex === Attribute.fromStructField(attr.toStructField()))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
new file mode 100644
index 0000000000..af88595df5
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -0,0 +1,274 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
+ DecisionTreeSuite => OldDecisionTreeSuite}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext {
+
+ import DecisionTreeClassifierSuite.compareAPIs
+
+ private var categoricalDataPointsRDD: RDD[LabeledPoint] = _
+ private var orderedLabeledPointsWithLabel0RDD: RDD[LabeledPoint] = _
+ private var orderedLabeledPointsWithLabel1RDD: RDD[LabeledPoint] = _
+ private var categoricalDataPointsForMulticlassRDD: RDD[LabeledPoint] = _
+ private var continuousDataPointsForMulticlassRDD: RDD[LabeledPoint] = _
+ private var categoricalDataPointsForMulticlassForOrderedFeaturesRDD: RDD[LabeledPoint] = _
+
+ override def beforeAll() {
+ super.beforeAll()
+ categoricalDataPointsRDD =
+ sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints())
+ orderedLabeledPointsWithLabel0RDD =
+ sc.parallelize(OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel0())
+ orderedLabeledPointsWithLabel1RDD =
+ sc.parallelize(OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel1())
+ categoricalDataPointsForMulticlassRDD =
+ sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlass())
+ continuousDataPointsForMulticlassRDD =
+ sc.parallelize(OldDecisionTreeSuite.generateContinuousDataPointsForMulticlass())
+ categoricalDataPointsForMulticlassForOrderedFeaturesRDD = sc.parallelize(
+ OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures())
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests calling train()
+ /////////////////////////////////////////////////////////////////////////////
+
+ test("Binary classification stump with ordered categorical features") {
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("gini")
+ .setMaxDepth(2)
+ .setMaxBins(100)
+    val categoricalFeatures = Map(0 -> 3, 1 -> 3)
+ val numClasses = 2
+ compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures, numClasses)
+ }
+
+ test("Binary classification stump with fixed labels 0,1 for Entropy,Gini") {
+ val dt = new DecisionTreeClassifier()
+ .setMaxDepth(3)
+ .setMaxBins(100)
+ val numClasses = 2
+ Array(orderedLabeledPointsWithLabel0RDD, orderedLabeledPointsWithLabel1RDD).foreach { rdd =>
+ DecisionTreeClassifier.supportedImpurities.foreach { impurity =>
+ dt.setImpurity(impurity)
+ compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
+ }
+ }
+ }
+
+ test("Multiclass classification stump with 3-ary (unordered) categorical features") {
+ val rdd = categoricalDataPointsForMulticlassRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ val numClasses = 3
+ val categoricalFeatures = Map(0 -> 3, 1 -> 3)
+ compareAPIs(rdd, dt, categoricalFeatures, numClasses)
+ }
+
+ test("Binary classification stump with 1 continuous feature, to check off-by-1 error") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(1.0, Vectors.dense(2.0)),
+ LabeledPoint(1.0, Vectors.dense(3.0)))
+ val rdd = sc.parallelize(arr)
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ val numClasses = 2
+ compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
+ }
+
+ test("Binary classification stump with 2 continuous features") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0)))))
+ val rdd = sc.parallelize(arr)
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ val numClasses = 2
+ compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
+ }
+
+ test("Multiclass classification stump with unordered categorical features," +
+ " with just enough bins") {
+ val maxBins = 2 * (math.pow(2, 3 - 1).toInt - 1) // just enough bins to allow unordered features
+ val rdd = categoricalDataPointsForMulticlassRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ .setMaxBins(maxBins)
+ val categoricalFeatures = Map(0 -> 3, 1 -> 3)
+ val numClasses = 3
+ compareAPIs(rdd, dt, categoricalFeatures, numClasses)
+ }
+
+ test("Multiclass classification stump with continuous features") {
+ val rdd = continuousDataPointsForMulticlassRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ .setMaxBins(100)
+ val numClasses = 3
+ compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
+ }
+
+ test("Multiclass classification stump with continuous + unordered categorical features") {
+ val rdd = continuousDataPointsForMulticlassRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ .setMaxBins(100)
+ val categoricalFeatures = Map(0 -> 3)
+ val numClasses = 3
+ compareAPIs(rdd, dt, categoricalFeatures, numClasses)
+ }
+
+ test("Multiclass classification stump with 10-ary (ordered) categorical features") {
+ val rdd = categoricalDataPointsForMulticlassForOrderedFeaturesRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ .setMaxBins(100)
+ val categoricalFeatures = Map(0 -> 10, 1 -> 10)
+ val numClasses = 3
+ compareAPIs(rdd, dt, categoricalFeatures, numClasses)
+ }
+
+ test("Multiclass classification tree with 10-ary (ordered) categorical features," +
+ " with just enough bins") {
+ val rdd = categoricalDataPointsForMulticlassForOrderedFeaturesRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ .setMaxBins(10)
+ val categoricalFeatures = Map(0 -> 10, 1 -> 10)
+ val numClasses = 3
+ compareAPIs(rdd, dt, categoricalFeatures, numClasses)
+ }
+
+ test("split must satisfy min instances per node requirements") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))))
+ val rdd = sc.parallelize(arr)
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(2)
+ .setMinInstancesPerNode(2)
+ val numClasses = 2
+ compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
+ }
+
+ test("do not choose split that does not satisfy min instance per node requirements") {
+    // If a split does not satisfy the min instances per node requirement,
+    // the split is invalid, even though its information gain is large.
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0, 0.0)))
+ val rdd = sc.parallelize(arr)
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxBins(2)
+ .setMaxDepth(2)
+ .setMinInstancesPerNode(2)
+    val categoricalFeatures = Map(0 -> 2, 1 -> 2)
+ val numClasses = 2
+ compareAPIs(rdd, dt, categoricalFeatures, numClasses)
+ }
+
+ test("split must satisfy min info gain requirements") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))))
+ val rdd = sc.parallelize(arr)
+
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(2)
+ .setMinInfoGain(1.0)
+ val numClasses = 2
+ compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests of model save/load
+ /////////////////////////////////////////////////////////////////////////////
+
+ // TODO: Reinstate test once save/load are implemented
+ /*
+ test("model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ val oldModel = OldDecisionTreeSuite.createModel(OldAlgo.Classification)
+ val newModel = DecisionTreeClassificationModel.fromOld(oldModel)
+
+ // Save model, load it back, and compare.
+ try {
+ newModel.save(sc, path)
+ val sameNewModel = DecisionTreeClassificationModel.load(sc, path)
+ TreeTests.checkEqual(newModel, sameNewModel)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+ */
+}
+
+private[ml] object DecisionTreeClassifierSuite extends FunSuite {
+
+ /**
+ * Train 2 decision trees on the given dataset, one using the old API and one using the new API.
+ * Convert the old tree to the new format, compare them, and fail if they are not exactly equal.
+ */
+ def compareAPIs(
+ data: RDD[LabeledPoint],
+ dt: DecisionTreeClassifier,
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int): Unit = {
+ val oldStrategy = dt.getOldStrategy(categoricalFeatures, numClasses)
+ val oldTree = OldDecisionTree.train(data, oldStrategy)
+ val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
+ val newTree = dt.fit(newData)
+ // Use parent, fittingParamMap from newTree since these are not checked anyways.
+ val oldTreeAsNew = DecisionTreeClassificationModel.fromOld(oldTree, newTree.parent,
+ newTree.fittingParamMap, categoricalFeatures)
+ TreeTests.checkEqual(oldTreeAsNew, newTree)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index 81ef831c42..1b261b2643 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -228,7 +228,7 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
}
val attrGroup = new AttributeGroup("features", featureAttributes)
val densePoints1WithMeta =
- densePoints1.select(densePoints1("features").as("features", attrGroup.toMetadata))
+ densePoints1.select(densePoints1("features").as("features", attrGroup.toMetadata()))
val vectorIndexer = getIndexer.setMaxCategories(2)
val model = vectorIndexer.fit(densePoints1WithMeta)
// Check that ML metadata are preserved.
diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
new file mode 100644
index 0000000000..2e57d4ce37
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.impl
+
+import scala.collection.JavaConverters._
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
+import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.tree.{DecisionTreeModel, InternalNode, LeafNode, Node}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{SQLContext, DataFrame}
+
+
+private[ml] object TreeTests extends FunSuite {
+
+ /**
+ * Convert the given data to a DataFrame, and set the features and label metadata.
+ * @param data Dataset. Categorical features and labels must already have 0-based indices.
+ * This must be non-empty.
+ * @param categoricalFeatures Map: categorical feature index -> number of distinct values
+ * @param numClasses Number of classes label can take. If 0, mark as continuous.
+ * @return DataFrame with metadata
+ */
+ def setMetadata(
+ data: RDD[LabeledPoint],
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int): DataFrame = {
+ val sqlContext = new SQLContext(data.sparkContext)
+ import sqlContext.implicits._
+ val df = data.toDF()
+ val numFeatures = data.first().features.size
+ val featuresAttributes = Range(0, numFeatures).map { feature =>
+ if (categoricalFeatures.contains(feature)) {
+ NominalAttribute.defaultAttr.withIndex(feature).withNumValues(categoricalFeatures(feature))
+ } else {
+ NumericAttribute.defaultAttr.withIndex(feature)
+ }
+ }.toArray
+ val featuresMetadata = new AttributeGroup("features", featuresAttributes).toMetadata()
+ val labelAttribute = if (numClasses == 0) {
+ NumericAttribute.defaultAttr.withName("label")
+ } else {
+ NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
+ }
+ val labelMetadata = labelAttribute.toMetadata()
+ df.select(df("features").as("features", featuresMetadata),
+ df("label").as("label", labelMetadata))
+ }
+
+ /** Java-friendly version of [[setMetadata()]] */
+ def setMetadata(
+ data: JavaRDD[LabeledPoint],
+ categoricalFeatures: java.util.Map[java.lang.Integer, java.lang.Integer],
+ numClasses: Int): DataFrame = {
+ setMetadata(data.rdd, categoricalFeatures.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
+ numClasses)
+ }
+
+ /**
+ * Check if the two trees are exactly the same.
+ * Note: I hesitate to override Node.equals since it could cause problems if users
+ * make mistakes such as creating loops of Nodes.
+   * If the trees are not equal, this throws an exception including both trees' debug strings.
+ */
+ def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = {
+ try {
+ checkEqual(a.rootNode, b.rootNode)
+ } catch {
+ case ex: Exception =>
+ throw new AssertionError("checkEqual failed since the two trees were not identical.\n" +
+ "TREE A:\n" + a.toDebugString + "\n" +
+ "TREE B:\n" + b.toDebugString + "\n", ex)
+ }
+ }
+
+ /**
+   * Check if the two nodes and their descendants are exactly the same; throw if not.
+ * Note: I hesitate to override Node.equals since it could cause problems if users
+ * make mistakes such as creating loops of Nodes.
+ */
+ private def checkEqual(a: Node, b: Node): Unit = {
+ assert(a.prediction === b.prediction)
+ assert(a.impurity === b.impurity)
+ (a, b) match {
+ case (aye: InternalNode, bee: InternalNode) =>
+ assert(aye.split === bee.split)
+ checkEqual(aye.leftChild, bee.leftChild)
+ checkEqual(aye.rightChild, bee.rightChild)
+ case (aye: LeafNode, bee: LeafNode) => // do nothing
+ case _ =>
+ throw new AssertionError("Found mismatched nodes")
+ }
+ }
+
+ // TODO: Reinstate after adding ensembles
+ /**
+ * Check if the two models are exactly the same.
+ * If the models are not equal, this throws an exception.
+ */
+ /*
+ def checkEqual(a: TreeEnsembleModel, b: TreeEnsembleModel): Unit = {
+ try {
+ a.getTrees.zip(b.getTrees).foreach { case (treeA, treeB) =>
+ TreeTests.checkEqual(treeA, treeB)
+ }
+ assert(a.getTreeWeights === b.getTreeWeights)
+ } catch {
+ case ex: Exception => throw new AssertionError(
+ "checkEqual failed since the two tree ensembles were not identical")
+ }
+ }
+ */
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
new file mode 100644
index 0000000000..0b40fe33fa
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
+ DecisionTreeSuite => OldDecisionTreeSuite}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext {
+
+ import DecisionTreeRegressorSuite.compareAPIs
+
+ private var categoricalDataPointsRDD: RDD[LabeledPoint] = _
+
+ override def beforeAll() {
+ super.beforeAll()
+ categoricalDataPointsRDD =
+ sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints())
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests calling train()
+ /////////////////////////////////////////////////////////////////////////////
+
+ test("Regression stump with 3-ary (ordered) categorical features") {
+ val dt = new DecisionTreeRegressor()
+ .setImpurity("variance")
+ .setMaxDepth(2)
+ .setMaxBins(100)
+    val categoricalFeatures = Map(0 -> 3, 1 -> 3)
+ compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures)
+ }
+
+ test("Regression stump with binary (ordered) categorical features") {
+ val dt = new DecisionTreeRegressor()
+ .setImpurity("variance")
+ .setMaxDepth(2)
+ .setMaxBins(100)
+    val categoricalFeatures = Map(0 -> 2, 1 -> 2)
+ compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures)
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests of model save/load
+ /////////////////////////////////////////////////////////////////////////////
+
+ // TODO: test("model save/load")
+}
+
+private[ml] object DecisionTreeRegressorSuite extends FunSuite {
+
+ /**
+ * Train 2 decision trees on the given dataset, one using the old API and one using the new API.
+ * Convert the old tree to the new format, compare them, and fail if they are not exactly equal.
+ */
+ def compareAPIs(
+ data: RDD[LabeledPoint],
+ dt: DecisionTreeRegressor,
+ categoricalFeatures: Map[Int, Int]): Unit = {
+ val oldStrategy = dt.getOldStrategy(categoricalFeatures)
+ val oldTree = OldDecisionTree.train(data, oldStrategy)
+ val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
+ val newTree = dt.fit(newData)
+ // Use parent, fittingParamMap from newTree since these are not checked anyways.
+ val oldTreeAsNew = DecisionTreeRegressionModel.fromOld(oldTree, newTree.parent,
+ newTree.fittingParamMap, categoricalFeatures)
+ TreeTests.checkEqual(oldTreeAsNew, newTree)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 4c162df810..249b8eae19 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -36,6 +36,10 @@ import org.apache.spark.util.Utils
class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests examining individual elements of training
+ /////////////////////////////////////////////////////////////////////////////
+
test("Binary classification with continuous features: split and bin calculation") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length === 1000)
@@ -254,6 +258,165 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
assert(bins(0).length === 0)
}
+ test("Avoid aggregation on the last level") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
+ val input = sc.parallelize(arr)
+
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
+ numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
+ val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+
+ val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+ val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
+
+ val topNode = Node.emptyNode(nodeIndex = 1)
+ assert(topNode.predict.predict === Double.MinValue)
+ assert(topNode.impurity === -1.0)
+ assert(topNode.isLeaf === false)
+
+ val nodesForGroup = Map((0, Array(topNode)))
+ val treeToNodeToIndexInfo = Map((0, Map(
+ (topNode.id, new RandomForest.NodeIndexInfo(0, None))
+ )))
+ val nodeQueue = new mutable.Queue[(Int, Node)]()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
+ nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
+
+ // don't enqueue leaf nodes into node queue
+ assert(nodeQueue.isEmpty)
+
+ // set impurity and predict for topNode
+ assert(topNode.predict.predict !== Double.MinValue)
+ assert(topNode.impurity !== -1.0)
+
+ // set impurity and predict for child nodes
+ assert(topNode.leftNode.get.predict.predict === 0.0)
+ assert(topNode.rightNode.get.predict.predict === 1.0)
+ assert(topNode.leftNode.get.impurity === 0.0)
+ assert(topNode.rightNode.get.impurity === 0.0)
+ }
+
+ test("Avoid aggregation if impurity is 0.0") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
+ val input = sc.parallelize(arr)
+
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+ numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
+ val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+
+ val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+ val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
+
+ val topNode = Node.emptyNode(nodeIndex = 1)
+ assert(topNode.predict.predict === Double.MinValue)
+ assert(topNode.impurity === -1.0)
+ assert(topNode.isLeaf === false)
+
+ val nodesForGroup = Map((0, Array(topNode)))
+ val treeToNodeToIndexInfo = Map((0, Map(
+ (topNode.id, new RandomForest.NodeIndexInfo(0, None))
+ )))
+ val nodeQueue = new mutable.Queue[(Int, Node)]()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
+ nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
+
+ // don't enqueue a node into node queue if its impurity is 0.0
+ assert(nodeQueue.isEmpty)
+
+ // set impurity and predict for topNode
+ assert(topNode.predict.predict !== Double.MinValue)
+ assert(topNode.impurity !== -1.0)
+
+ // set impurity and predict for child nodes
+ assert(topNode.leftNode.get.predict.predict === 0.0)
+ assert(topNode.rightNode.get.predict.predict === 1.0)
+ assert(topNode.leftNode.get.impurity === 0.0)
+ assert(topNode.rightNode.get.impurity === 0.0)
+ }
+
+ test("Second level node building with vs. without groups") {
+ val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+ assert(splits.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins.length === 2)
+ assert(bins(0).length === 100)
+
+ // Train a 1-node model
+ val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1,
+ numClasses = 2, maxBins = 100)
+ val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
+ val rootNode1 = modelOneNode.topNode.deepCopy()
+ val rootNode2 = modelOneNode.topNode.deepCopy()
+ assert(rootNode1.leftNode.nonEmpty)
+ assert(rootNode1.rightNode.nonEmpty)
+
+ val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+ val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
+
+ // Single group second level tree construction.
+ val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get)))
+ val treeToNodeToIndexInfo = Map((0, Map(
+ (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)),
+ (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None)))))
+ val nodeQueue = new mutable.Queue[(Int, Node)]()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1),
+ nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
+ val children1 = new Array[Node](2)
+ children1(0) = rootNode1.leftNode.get
+ children1(1) = rootNode1.rightNode.get
+
+ // Train one second-level node at a time.
+ val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get)))
+ val treeToNodeToIndexInfoA = Map((0, Map(
+ (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
+ nodeQueue.clear()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
+ nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue)
+ val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get)))
+ val treeToNodeToIndexInfoB = Map((0, Map(
+ (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
+ nodeQueue.clear()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
+ nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue)
+ val children2 = new Array[Node](2)
+ children2(0) = rootNode2.leftNode.get
+ children2(1) = rootNode2.rightNode.get
+
+ // Verify whether the splits obtained using single group and multiple group level
+ // construction strategies are the same.
+ for (i <- 0 until 2) {
+ assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0)
+ assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0)
+ assert(children1(i).split === children2(i).split)
+ assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty)
+ val stats1 = children1(i).stats.get
+ val stats2 = children2(i).stats.get
+ assert(stats1.gain === stats2.gain)
+ assert(stats1.impurity === stats2.impurity)
+ assert(stats1.leftImpurity === stats2.leftImpurity)
+ assert(stats1.rightImpurity === stats2.rightImpurity)
+ assert(children1(i).predict.predict === children2(i).predict.predict)
+ }
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests calling train()
+ /////////////////////////////////////////////////////////////////////////////
test("Binary classification stump with ordered categorical features") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
@@ -438,76 +601,6 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
assert(rootNode.predict.predict === 1)
}
- test("Second level node building with vs. without groups") {
- val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
- assert(arr.length === 1000)
- val rdd = sc.parallelize(arr)
- val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- assert(splits.length === 2)
- assert(splits(0).length === 99)
- assert(bins.length === 2)
- assert(bins(0).length === 100)
-
- // Train a 1-node model
- val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1,
- numClasses = 2, maxBins = 100)
- val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
- val rootNode1 = modelOneNode.topNode.deepCopy()
- val rootNode2 = modelOneNode.topNode.deepCopy()
- assert(rootNode1.leftNode.nonEmpty)
- assert(rootNode1.rightNode.nonEmpty)
-
- val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
-
- // Single group second level tree construction.
- val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get)))
- val treeToNodeToIndexInfo = Map((0, Map(
- (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)),
- (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None)))))
- val nodeQueue = new mutable.Queue[(Int, Node)]()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1),
- nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
- val children1 = new Array[Node](2)
- children1(0) = rootNode1.leftNode.get
- children1(1) = rootNode1.rightNode.get
-
- // Train one second-level node at a time.
- val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get)))
- val treeToNodeToIndexInfoA = Map((0, Map(
- (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
- nodeQueue.clear()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
- nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue)
- val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get)))
- val treeToNodeToIndexInfoB = Map((0, Map(
- (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
- nodeQueue.clear()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
- nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue)
- val children2 = new Array[Node](2)
- children2(0) = rootNode2.leftNode.get
- children2(1) = rootNode2.rightNode.get
-
- // Verify whether the splits obtained using single group and multiple group level
- // construction strategies are the same.
- for (i <- 0 until 2) {
- assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0)
- assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0)
- assert(children1(i).split === children2(i).split)
- assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty)
- val stats1 = children1(i).stats.get
- val stats2 = children2(i).stats.get
- assert(stats1.gain === stats2.gain)
- assert(stats1.impurity === stats2.impurity)
- assert(stats1.leftImpurity === stats2.leftImpurity)
- assert(stats1.rightImpurity === stats2.rightImpurity)
- assert(children1(i).predict.predict === children2(i).predict.predict)
- }
- }
-
test("Multiclass classification stump with 3-ary (unordered) categorical features") {
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
val rdd = sc.parallelize(arr)
@@ -528,11 +621,11 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
}
test("Binary classification stump with 1 continuous feature, to check off-by-1 error") {
- val arr = new Array[LabeledPoint](4)
- arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0))
- arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0))
- arr(2) = new LabeledPoint(1.0, Vectors.dense(2.0))
- arr(3) = new LabeledPoint(1.0, Vectors.dense(3.0))
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(1.0, Vectors.dense(2.0)),
+ LabeledPoint(1.0, Vectors.dense(3.0)))
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClasses = 2)
@@ -544,11 +637,11 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
}
test("Binary classification stump with 2 continuous features") {
- val arr = new Array[LabeledPoint](4)
- arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
- arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
- arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
- arr(3) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0))))
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0)))))
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
@@ -668,11 +761,10 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
}
test("split must satisfy min instances per node requirements") {
- val arr = new Array[LabeledPoint](3)
- arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
- arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
- arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))
-
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))))
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini,
maxDepth = 2, numClasses = 2, minInstancesPerNode = 2)
@@ -695,11 +787,11 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
test("do not choose split that does not satisfy min instance per node requirements") {
// if a split does not satisfy the min instances per node requirement,
// the split is invalid, even though its information gain may be large.
- val arr = new Array[LabeledPoint](4)
- arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0))
- arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0, 1.0))
- arr(2) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0))
- arr(3) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0))
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0, 0.0)))
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini,
@@ -715,10 +807,10 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
}
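Both minInstancesPerNode tests in this hunk pin down the same rule: a candidate split is valid only if each child receives at least minInstancesPerNode training instances. A minimal, Spark-independent sketch of that predicate (splitIsValid is a hypothetical name, not the actual DecisionTree internals):

// Hypothetical sketch of the minInstancesPerNode check these tests exercise:
// a split is rejected when either child would receive too few instances.
def splitIsValid(leftCount: Long, rightCount: Long, minInstancesPerNode: Int): Boolean =
  leftCount >= minInstancesPerNode && rightCount >= minInstancesPerNode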
test("split must satisfy min info gain requirements") {
- val arr = new Array[LabeledPoint](3)
- arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
- arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
- arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))))
val input = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
@@ -739,91 +831,9 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
assert(gain == InformationGainStats.invalidInformationGainStats)
}
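The assertion above relies on invalid splits being mapped to the sentinel InformationGainStats.invalidInformationGainStats. A rough sketch of that gating step, assuming it is applied per candidate split (gateByMinInfoGain is a hypothetical name):

import org.apache.spark.mllib.tree.model.InformationGainStats

// Hypothetical sketch: a split whose gain falls below minInfoGain is replaced
// by the invalid-stats sentinel, which is what the test above asserts.
def gateByMinInfoGain(stats: InformationGainStats, minInfoGain: Double): InformationGainStats =
  if (stats.gain >= minInfoGain) stats
  else InformationGainStats.invalidInformationGainStats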
- test("Avoid aggregation on the last level") {
- val arr = new Array[LabeledPoint](4)
- arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0))
- arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))
- arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0))
- arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))
- val input = sc.parallelize(arr)
-
- val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
- numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
- val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
- val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
-
- val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
- val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
-
- val topNode = Node.emptyNode(nodeIndex = 1)
- assert(topNode.predict.predict === Double.MinValue)
- assert(topNode.impurity === -1.0)
- assert(topNode.isLeaf === false)
-
- val nodesForGroup = Map((0, Array(topNode)))
- val treeToNodeToIndexInfo = Map((0, Map(
- (topNode.id, new RandomForest.NodeIndexInfo(0, None))
- )))
- val nodeQueue = new mutable.Queue[(Int, Node)]()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
- nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
-
- // don't enqueue leaf nodes into the node queue
- assert(nodeQueue.isEmpty)
-
- // impurity and predict should now be set for topNode
- assert(topNode.predict.predict !== Double.MinValue)
- assert(topNode.impurity !== -1.0)
-
- // impurity and predict should now be set for the child nodes
- assert(topNode.leftNode.get.predict.predict === 0.0)
- assert(topNode.rightNode.get.predict.predict === 1.0)
- assert(topNode.leftNode.get.impurity === 0.0)
- assert(topNode.rightNode.get.impurity === 0.0)
- }
-
- test("Avoid aggregation if impurity is 0.0") {
- val arr = new Array[LabeledPoint](4)
- arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0))
- arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))
- arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0))
- arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))
- val input = sc.parallelize(arr)
-
- val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
- numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
- val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
- val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
-
- val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
- val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
-
- val topNode = Node.emptyNode(nodeIndex = 1)
- assert(topNode.predict.predict === Double.MinValue)
- assert(topNode.impurity === -1.0)
- assert(topNode.isLeaf === false)
-
- val nodesForGroup = Map((0, Array(topNode)))
- val treeToNodeToIndexInfo = Map((0, Map(
- (topNode.id, new RandomForest.NodeIndexInfo(0, None))
- )))
- val nodeQueue = new mutable.Queue[(Int, Node)]()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
- nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
-
- // don't enqueue a node into the node queue if its impurity is 0.0
- assert(nodeQueue.isEmpty)
-
- // impurity and predict should now be set for topNode
- assert(topNode.predict.predict !== Double.MinValue)
- assert(topNode.impurity !== -1.0)
-
- // impurity and predict should now be set for the child nodes
- assert(topNode.leftNode.get.predict.predict === 0.0)
- assert(topNode.rightNode.get.predict.predict === 1.0)
- assert(topNode.leftNode.get.impurity === 0.0)
- assert(topNode.rightNode.get.impurity === 0.0)
- }
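The two removed tests above encode one invariant: a node that is pure (impurity 0.0) or already at maxDepth becomes a leaf and is never enqueued for further splitting, so no aggregation pass is spent on it. A compact sketch of that stopping rule (isFinishedNode is a hypothetical name, not the mllib internals):

// Hypothetical sketch of the stopping rule exercised by the removed tests:
// stop splitting once a node is pure or the depth limit is reached.
def isFinishedNode(impurity: Double, depth: Int, maxDepth: Int): Boolean =
  impurity == 0.0 || depth >= maxDepth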
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests of model save/load
+ /////////////////////////////////////////////////////////////////////////////
test("Node.subtreeIterator") {
val model = DecisionTreeSuite.createModel(Classification)
@@ -996,8 +1006,9 @@ object DecisionTreeSuite extends FunSuite {
/**
* Create a tree model. This is deterministic and contains a variety of node and feature types.
+ * TODO: Update this to be a correct tree (with matching probabilities, impurities, etc.)
*/
- private[tree] def createModel(algo: Algo): DecisionTreeModel = {
+ private[mllib] def createModel(algo: Algo): DecisionTreeModel = {
val topNode = createInternalNode(id = 1, Continuous)
val (node2, node3) = (createLeafNode(id = 2), createInternalNode(id = 3, Categorical))
val (node6, node7) = (createLeafNode(id = 6), createLeafNode(id = 7))
@@ -1017,7 +1028,7 @@ object DecisionTreeSuite extends FunSuite {
* make mistakes such as creating loops of Nodes.
* If the trees are not equal, this prints the two trees and throws an exception.
*/
- private[tree] def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = {
+ private[mllib] def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = {
try {
assert(a.algo === b.algo)
checkEqual(a.topNode, b.topNode)