author    Joseph K. Bradley <joseph@databricks.com>    2016-03-23 21:16:00 -0700
committer Joseph K. Bradley <joseph@databricks.com>    2016-03-23 21:16:00 -0700
commit    cf823bead18c5be86b36da59b4bbf935c4804d04 (patch)
tree      7e48dd6b225e4f2ce670d6c5513215c914053194
parent    f42eaf42bdca8bc6f390f1f31ee60faa1662489b (diff)
download  spark-cf823bead18c5be86b36da59b4bbf935c4804d04.tar.gz
          spark-cf823bead18c5be86b36da59b4bbf935c4804d04.tar.bz2
          spark-cf823bead18c5be86b36da59b4bbf935c4804d04.zip
[SPARK-12183][ML][MLLIB] Remove mllib tree implementation, and wrap spark.ml one
Primary change:
* Removed the spark.mllib.tree.DecisionTree implementation of tree and forest learning.
* spark.mllib now calls the spark.ml implementation.
* Moved unit tests (of tree learning internals) from spark.mllib to spark.ml as needed.

ml.tree.DecisionTreeModel
* Added toOld and made it ```private[spark]```, implemented for Classifier and Regressor in subclasses. These methods now use OldInformationGainStats.invalidInformationGainStats for LeafNodes in order to mimic the spark.mllib implementation.

ml.tree.Node
* Added ```private[tree] def deepCopy```, used by unit tests.

Copied developer comments from the spark.mllib implementation to the spark.ml one.

Moving unit tests
* Tree learning internals were tested by spark.mllib.tree.DecisionTreeSuite or spark.mllib.tree.RandomForestSuite.
* Those tests were all moved to spark.ml.tree.impl.RandomForestSuite. The order in the file and the test names are the same, so the two versions can be compared by opening them side by side.
* I made minimal changes to each test to allow it to run. Each test makes the same checks as before, except for a few removed assertions which were checking irrelevant values.
* No new unit tests were added.
* mllib.tree.DecisionTreeSuite: I removed some checks of splits and bins which were not relevant to the unit tests they were in. Those same split calculations were already being tested in other unit tests, for each dataset type.

**Changes of behavior** (to be noted in SPARK-13448 once this PR is merged)
* spark.ml.tree.impl.RandomForest: Rather than throwing an error when maxMemoryInMB is set to too small a value (to split any node), we now allow 1 node to be split, even if its memory requirements exceed maxMemoryInMB. This involved removing the maxMemoryPerNode check in RandomForest.run, as well as modifying selectNodesToSplit().
* spark.mllib.tree.DecisionTree: When a tree has only one node (root = leaf node), the "stats" field will now be empty, rather than being set to InformationGainStats.invalidInformationGainStats. This does not remove information from the tree, and it will save a bit of storage.

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #11855 from jkbradley/remove-mllib-tree-impl.
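To make the primary change concrete, the sketch below shows the general shape of the delegation: the old spark.mllib entry point trains with the new spark.ml implementation (imported later in this patch as NewRandomForest) and converts the result back through the new ```private[spark]``` toOld hook. Everything in the sketch is a simplified, hypothetical stand-in (OldTreeModel, NewTreeModel, trainNewTrees are placeholder names, not Spark classes); it illustrates the pattern, not the exact code in the patch.

```scala
// Hypothetical, self-contained sketch of the "wrap the new implementation" pattern.
object DelegationSketch {
  // Placeholder for the old spark.mllib model type.
  final case class OldTreeModel(numNodes: Int)

  // Placeholder for the new spark.ml model type, with the toOld conversion hook
  // this patch adds (the real conversion may drop information that only the new
  // representation keeps).
  trait NewTreeModel {
    def numNodes: Int
    def toOld: OldTreeModel = OldTreeModel(numNodes)
  }

  // Placeholder for the new training routine (the real one is
  // org.apache.spark.ml.tree.impl.RandomForest.run).
  def trainNewTrees(numTrees: Int): Seq[NewTreeModel] =
    Seq.fill(numTrees)(new NewTreeModel { val numNodes = 1 })

  // The old entry point now just delegates and converts the result.
  def trainOldStyle(): OldTreeModel = trainNewTrees(numTrees = 1).head.toOld
}
```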
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala  23
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala  93
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala  7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala  5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala  914
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala  266
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala  3
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala  10
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala  418
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala  486
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala  83
-rw-r--r--  python/pyspark/ml/param/_shared_params_code_gen.py  3
-rw-r--r--  python/pyspark/ml/param/shared.py  2
15 files changed, 541 insertions, 1780 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 6ea1abb49b..3e4b21bff6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -205,8 +205,8 @@ final class DecisionTreeClassificationModel private[ml] (
@Since("2.0.0")
lazy val featureImportances: Vector = RandomForest.featureImportances(this, numFeatures)
- /** (private[ml]) Convert to a model in the old API */
- private[ml] def toOld: OldDecisionTreeModel = {
+ /** Convert to spark.mllib DecisionTreeModel (losing some information) */
+ override private[spark] def toOld: OldDecisionTreeModel = {
new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Classification)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index fa7cc436f0..50ac96eb5e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -205,8 +205,8 @@ final class DecisionTreeRegressionModel private[ml] (
@Since("2.0.0")
lazy val featureImportances: Vector = RandomForest.featureImportances(this, numFeatures)
- /** Convert to a model in the old API */
- private[ml] def toOld: OldDecisionTreeModel = {
+ /** Convert to spark.mllib DecisionTreeModel (losing some information) */
+ override private[spark] def toOld: OldDecisionTreeModel = {
new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
index 6507a8ad7c..b5cb378829 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -78,6 +78,9 @@ sealed abstract class Node extends Serializable {
* @return Max feature index used in a split, or -1 if there are no splits (single leaf node).
*/
private[ml] def maxSplitFeatureIndex(): Int
+
+ /** Returns a deep copy of the subtree rooted at this node. */
+ private[tree] def deepCopy(): Node
}
private[ml] object Node {
@@ -137,6 +140,10 @@ final class LeafNode private[ml] (
}
override private[ml] def maxSplitFeatureIndex(): Int = -1
+
+ override private[tree] def deepCopy(): Node = {
+ new LeafNode(prediction, impurity, impurityStats)
+ }
}
/**
@@ -203,6 +210,11 @@ final class InternalNode private[ml] (
math.max(split.featureIndex,
math.max(leftChild.maxSplitFeatureIndex(), rightChild.maxSplitFeatureIndex()))
}
+
+ override private[tree] def deepCopy(): Node = {
+ new InternalNode(prediction, impurity, gain, leftChild.deepCopy(), rightChild.deepCopy(),
+ split, impurityStats)
+ }
}
private object InternalNode {
@@ -286,11 +298,12 @@ private[tree] class LearningNode(
*
* @param binnedFeatures Binned feature vector for data point.
* @param splits possible splits for all features, indexed (numFeatures)(numSplits)
- * @return Leaf index if the data point reaches a leaf.
- * Otherwise, last node reachable in tree matching this example.
- * Note: This is the global node index, i.e., the index used in the tree.
- * This index is different from the index used during training a particular
- * group of nodes on one call to [[findBestSplits()]].
+ * @return Leaf index if the data point reaches a leaf.
+ * Otherwise, last node reachable in tree matching this example.
+ * Note: This is the global node index, i.e., the index used in the tree.
+ * This index is different from the index used during training a particular
+ * group of nodes on one call to
+ * [[org.apache.spark.ml.tree.impl.RandomForest.findBestSplits()]].
*/
def predictImpl(binnedFeatures: Array[Int], splits: Array[Array[Split]]): Int = {
if (this.isLeaf || this.split.isEmpty) {
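A note for readers of the hunks above: the old spark.mllib API addresses nodes by a 1-based, heap-style global index (root = 1, children of node i at 2*i and 2*i + 1), which is why the classifier and regressor call rootNode.toOld(1) and why predictImpl is documented as returning a global node index. The snippet below is a standalone illustration of that indexing convention under that assumption, not the actual Node helpers.

```scala
// Standalone illustration of the 1-based, heap-style node indexing convention.
object NodeIndexingSketch {
  def leftChildIndex(nodeIndex: Int): Int = nodeIndex * 2
  def rightChildIndex(nodeIndex: Int): Int = nodeIndex * 2 + 1
  def parentIndex(nodeIndex: Int): Int = nodeIndex / 2
  // Root is at level 0; the level equals floor(log2(nodeIndex)).
  def indexToLevel(nodeIndex: Int): Int = 31 - Integer.numberOfLeadingZeros(nodeIndex)

  def demo(): Unit = {
    assert(leftChildIndex(1) == 2 && rightChildIndex(1) == 3)
    assert(parentIndex(7) == 3)
    assert(indexToLevel(1) == 0 && indexToLevel(5) == 2)
  }
}
```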
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index afbb9d974d..7774ae64e5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -39,7 +39,48 @@ import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
-private[ml] object RandomForest extends Logging {
+/**
+ * ALGORITHM
+ *
+ * This is a sketch of the algorithm to help new developers.
+ *
+ * The algorithm partitions data by instances (rows).
+ * On each iteration, the algorithm splits a set of nodes. In order to choose the best split
+ * for a given node, sufficient statistics are collected from the distributed data.
+ * For each node, the statistics are collected to some worker node, and that worker selects
+ * the best split.
+ *
+ * This setup requires discretization of continuous features. This binning is done in the
+ * findSplits() method during initialization, after which each continuous feature becomes
+ * an ordered discretized feature with at most maxBins possible values.
+ *
+ * The main loop in the algorithm operates on a queue of nodes (nodeQueue). These nodes
+ * lie at the periphery of the tree being trained. If multiple trees are being trained at once,
+ * then this queue contains nodes from all of them. Each iteration works roughly as follows:
+ * On the master node:
+ * - Some number of nodes are pulled off of the queue (based on the amount of memory
+ * required for their sufficient statistics).
+ * - For random forests, if featureSubsetStrategy is not "all," then a subset of candidate
+ * features are chosen for each node. See method selectNodesToSplit().
+ * On worker nodes, via method findBestSplits():
+ * - The worker makes one pass over its subset of instances.
+ * - For each (tree, node, feature, split) tuple, the worker collects statistics about
+ * splitting. Note that the set of (tree, node) pairs is limited to the nodes selected
+ * from the queue for this iteration. The set of features considered can also be limited
+ * based on featureSubsetStrategy.
+ * - For each node, the statistics for that node are aggregated to a particular worker
+ * via reduceByKey(). The designated worker chooses the best (feature, split) pair,
+ * or chooses to stop splitting if the stopping criteria are met.
+ * On the master node:
+ * - The master collects all decisions about splitting nodes and updates the model.
+ * - The updated model is passed to the workers on the next iteration.
+ * This process continues until the node queue is empty.
+ *
+ * Most of the methods in this implementation support the statistics aggregation, which is
+ * the heaviest part of the computation. In general, this implementation is bound by either
+ * the cost of statistics computation on workers or by communicating the sufficient statistics.
+ */
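To summarize the ALGORITHM description above in code form, here is a heavily condensed, self-contained sketch of the queue-driven, group-wise control flow. Every name in it (TreeNodeRef, SplitDecision, selectGroup, findBestSplits as plain functions) is a hypothetical placeholder; the real work happens in selectNodesToSplit() and findBestSplits() further down in this file.

```scala
import scala.collection.mutable

// Hypothetical stand-ins for the real structures; this sketch only shows control flow.
final case class TreeNodeRef(treeIndex: Int, nodeId: Int)
final case class SplitDecision(stop: Boolean, children: Seq[TreeNodeRef])

object TrainingLoopSketch {
  def train(
      numTrees: Int,
      selectGroup: mutable.Queue[TreeNodeRef] => Seq[TreeNodeRef],
      findBestSplits: Seq[TreeNodeRef] => Map[TreeNodeRef, SplitDecision]): Unit = {
    // Frontier of nodes across all trees; start with each tree's root (index 1).
    val nodeQueue = mutable.Queue(Seq.tabulate(numTrees)(t => TreeNodeRef(t, nodeId = 1)): _*)
    while (nodeQueue.nonEmpty) {
      // Master: pull as many nodes as fit the memory budget (always at least one).
      val group = selectGroup(nodeQueue)
      // Workers: one pass over the data to aggregate sufficient statistics,
      // then (conceptually) a reduceByKey to pick the best split per node.
      val decisions = findBestSplits(group)
      // Master: update the model and enqueue any newly created internal nodes.
      decisions.valuesIterator
        .filterNot(_.stop)
        .flatMap(_.children)
        .foreach(child => nodeQueue.enqueue(child))
    }
  }
}
```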
+private[spark] object RandomForest extends Logging {
/**
* Train a random forest.
@@ -73,9 +114,9 @@ private[ml] object RandomForest extends Logging {
// Find the splits and the corresponding bins (interval between the splits) using a sample
// of the input data.
- timer.start("findSplitsBins")
+ timer.start("findSplits")
val splits = findSplits(retaggedInput, metadata, seed)
- timer.stop("findSplitsBins")
+ timer.stop("findSplits")
logDebug("numBins: feature: number of bins")
logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
@@ -100,22 +141,6 @@ private[ml] object RandomForest extends Logging {
// TODO: Calculate memory usage more precisely.
val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
- val maxMemoryPerNode = {
- val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
- // Find numFeaturesPerNode largest bins to get an upper bound on memory usage.
- Some(metadata.numBins.zipWithIndex.sortBy(- _._1)
- .take(metadata.numFeaturesPerNode).map(_._2))
- } else {
- None
- }
- RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
- }
- require(maxMemoryPerNode <= maxMemoryUsage,
- s"RandomForest/DecisionTree given maxMemoryInMB = ${strategy.maxMemoryInMB}," +
- " which is too small for the given features." +
- s" Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}")
-
- timer.stop("init")
/*
* The main idea here is to perform group-wise training of the decision tree nodes thus
@@ -146,6 +171,8 @@ private[ml] object RandomForest extends Logging {
val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1))
Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
+ timer.stop("init")
+
while (nodeQueue.nonEmpty) {
// Collect some nodes to split, and choose features for each node (if subsampling).
// Each group of nodes may come from one or multiple trees, and at multiple levels.
@@ -788,7 +815,7 @@ private[ml] object RandomForest extends Logging {
}
/**
- * Returns splits and bins for decision tree calculation.
+ * Returns splits for decision tree calculation.
* Continuous and categorical features are handled differently.
*
* Continuous features:
@@ -811,11 +838,8 @@ private[ml] object RandomForest extends Logging {
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
* @param metadata Learning and dataset metadata
* @param seed random seed
- * @return A tuple of (splits, bins).
- * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]]
- * of size (numFeatures, numSplits).
- * Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]]
- * of size (numFeatures, numBins).
+ * @return Splits, an Array of [[org.apache.spark.mllib.tree.model.Split]]
+ * of size (numFeatures, numSplits)
*/
protected[tree] def findSplits(
input: RDD[LabeledPoint],
@@ -842,10 +866,10 @@ private[ml] object RandomForest extends Logging {
input.sparkContext.emptyRDD[LabeledPoint]
}
- findSplitsBinsBySorting(sampledInput, metadata, continuousFeatures)
+ findSplitsBySorting(sampledInput, metadata, continuousFeatures)
}
- private def findSplitsBinsBySorting(
+ private def findSplitsBySorting(
input: RDD[LabeledPoint],
metadata: DecisionTreeMetadata,
continuousFeatures: IndexedSeq[Int]): Array[Array[Split]] = {
@@ -885,8 +909,7 @@ private[ml] object RandomForest extends Logging {
case i if metadata.isCategorical(i) =>
// Ordered features
- // Bins correspond to feature values, so we do not need to compute splits or bins
- // beforehand. Splits are constructed as needed during training.
+ // Splits are constructed as needed during training.
Array.empty[Split]
}
splits
@@ -1025,7 +1048,9 @@ private[ml] object RandomForest extends Logging {
new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]()
var memUsage: Long = 0L
var numNodesInGroup = 0
- while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) {
+ // If maxMemoryInMB is set very small, we want to still try to split 1 node,
+ // so we allow one iteration if memUsage == 0.
+ while (nodeQueue.nonEmpty && (memUsage < maxMemoryUsage || memUsage == 0)) {
val (treeIndex, node) = nodeQueue.head
// Choose subset of features for node (if subsampling).
val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
@@ -1036,7 +1061,7 @@ private[ml] object RandomForest extends Logging {
}
// Check if enough memory remains to add this node to the group.
val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
- if (memUsage + nodeMemUsage <= maxMemoryUsage) {
+ if (memUsage + nodeMemUsage <= maxMemoryUsage || memUsage == 0) {
nodeQueue.dequeue()
mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[LearningNode]()) +=
node
@@ -1047,6 +1072,12 @@ private[ml] object RandomForest extends Logging {
numNodesInGroup += 1
memUsage += nodeMemUsage
}
+ if (memUsage > maxMemoryUsage) {
+ // If maxMemoryUsage is 0, we should still allow splitting 1 node.
+ logWarning(s"Tree learning is using approximately $memUsage bytes per iteration, which" +
+ s" exceeds requested limit maxMemoryUsage=$maxMemoryUsage. This allows splitting" +
+ s" $numNodesInGroup nodes in this iteration.")
+ }
// Convert mutable maps to immutable ones.
val nodesForGroup: Map[Int, Array[LearningNode]] =
mutableNodesForGroup.mapValues(_.toArray).toMap
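The selectNodesToSplit() changes above implement the behavior change noted in the commit message: rather than failing a require() check when even a single node's aggregates would exceed maxMemoryInMB, the loop now always admits at least one node per iteration and logs a warning. From a user's perspective this looks roughly as follows (a hedged, spark-shell-style sketch; the tiny memory value is purely illustrative and the dataset setup is omitted).

```scala
import org.apache.spark.ml.classification.DecisionTreeClassifier

// Before this patch, an unrealistically small maxMemoryInMB could abort training
// with a require() failure; afterwards, training proceeds one node per iteration
// and emits a warning instead.
val dt = new DecisionTreeClassifier()
  .setMaxDepth(3)
  .setMaxMemoryInMB(1) // far below what the node aggregates may need; now tolerated
// val model = dt.fit(trainingData) // trainingData: DataFrame with label/features columns
```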
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 3e72e85d10..ef40c9068f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -25,6 +25,7 @@ import org.apache.spark.ml.param.Param
import org.apache.spark.ml.util.DefaultParamsReader
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.sql.SQLContext
/**
@@ -32,7 +33,7 @@ import org.apache.spark.sql.SQLContext
*
* TODO: Add support for predicting probabilities and raw predictions SPARK-3727
*/
-private[ml] trait DecisionTreeModel {
+private[spark] trait DecisionTreeModel {
/** Root of the decision tree */
def rootNode: Node
@@ -64,9 +65,13 @@ private[ml] trait DecisionTreeModel {
/**
* Trace down the tree, and return the largest feature index used in any split.
+ *
* @return Max feature index used in a split, or -1 if there are no splits (single leaf node).
*/
private[ml] def maxSplitFeatureIndex(): Int = rootNode.maxSplitFeatureIndex()
+
+ /** Convert to spark.mllib DecisionTreeModel (losing some information) */
+ private[spark] def toOld: OldDecisionTreeModel
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 3f2d0c7198..4fbd957677 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -78,7 +78,8 @@ private[ml] trait DecisionTreeParams extends PredictorParams
"Minimum information gain for a split to be considered at a tree node.")
/**
- * Maximum memory in MB allocated to histogram aggregation.
+ * Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be
+ * split per iteration, and its aggregates may exceed this size.
* (default = 256 MB)
* @group expertParam
*/
@@ -376,7 +377,7 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams {
final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase
}
-private[ml] object RandomForestParams {
+private[spark] object RandomForestParams {
// These options should be lowercase.
final val supportedFeatureSubsetStrategies: Array[String] =
Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index c40d5e3fff..21810a3b11 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -17,24 +17,19 @@
package org.apache.spark.mllib.tree
-import scala.annotation.tailrec
import scala.collection.JavaConverters._
-import scala.collection.mutable
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo
import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impl._
import org.apache.spark.mllib.tree.impurity._
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
-import org.apache.spark.util.random.XORShiftRandom
+
/**
* A class which implements a decision tree learning algorithm for classification and regression.
@@ -281,911 +276,4 @@ object DecisionTree extends Serializable with Logging {
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
impurity, maxDepth, maxBins)
}
-
- /**
- * Get the node index corresponding to this data point.
- * This function mimics prediction, passing an example from the root node down to a leaf
- * or unsplit node; that node's index is returned.
- *
- * @param node Node in tree from which to classify the given data point.
- * @param binnedFeatures Binned feature vector for data point.
- * @param bins Possible bins for all features, indexed (numFeatures)(numBins).
- * @param unorderedFeatures Set of indices of unordered features.
- * @return Leaf index if the data point reaches a leaf.
- * Otherwise, last node reachable in tree matching this example.
- * Note: This is the global node index, i.e., the index used in the tree.
- * This index is different from the index used during training a particular
- * group of nodes on one call to [[findBestSplits()]].
- */
- @tailrec
- private def predictNodeIndex(
- node: Node,
- binnedFeatures: Array[Int],
- bins: Array[Array[Bin]],
- unorderedFeatures: Set[Int]): Int = {
- if (node.isLeaf || node.split.isEmpty) {
- // Node is either leaf, or has not yet been split.
- node.id
- } else {
- val featureIndex = node.split.get.feature
- val splitLeft = node.split.get.featureType match {
- case Continuous => {
- val binIndex = binnedFeatures(featureIndex)
- val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
- // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold]
- // We do not need to check lowSplit since bins are separated by splits.
- featureValueUpperBound <= node.split.get.threshold
- }
- case Categorical => {
- val featureValue = binnedFeatures(featureIndex)
- node.split.get.categories.contains(featureValue)
- }
- case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.")
- }
- if (node.leftNode.isEmpty || node.rightNode.isEmpty) {
- // Return index from next layer of nodes to train
- if (splitLeft) {
- Node.leftChildIndex(node.id)
- } else {
- Node.rightChildIndex(node.id)
- }
- } else {
- if (splitLeft) {
- predictNodeIndex(node.leftNode.get, binnedFeatures, bins, unorderedFeatures)
- } else {
- predictNodeIndex(node.rightNode.get, binnedFeatures, bins, unorderedFeatures)
- }
- }
- }
- }
-
- /**
- * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features.
- *
- * For ordered features, a single bin is updated.
- * For unordered features, bins correspond to subsets of categories; either the left or right bin
- * for each subset is updated.
- *
- * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
- * each (feature, bin).
- * @param treePoint Data point being aggregated.
- * @param splits Possible splits indexed (numFeatures)(numSplits).
- * @param unorderedFeatures Set of indices of unordered features.
- * @param instanceWeight Weight (importance) of instance in dataset.
- */
- private def mixedBinSeqOp(
- agg: DTStatsAggregator,
- treePoint: TreePoint,
- splits: Array[Array[Split]],
- unorderedFeatures: Set[Int],
- instanceWeight: Double,
- featuresForNode: Option[Array[Int]]): Unit = {
- val numFeaturesPerNode = if (featuresForNode.nonEmpty) {
- // Use subsampled features
- featuresForNode.get.length
- } else {
- // Use all features
- agg.metadata.numFeatures
- }
- // Iterate over features.
- var featureIndexIdx = 0
- while (featureIndexIdx < numFeaturesPerNode) {
- val featureIndex = if (featuresForNode.nonEmpty) {
- featuresForNode.get.apply(featureIndexIdx)
- } else {
- featureIndexIdx
- }
- if (unorderedFeatures.contains(featureIndex)) {
- // Unordered feature
- val featureValue = treePoint.binnedFeatures(featureIndex)
- val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx)
- // Update the left or right bin for each split.
- val numSplits = agg.metadata.numSplits(featureIndex)
- var splitIndex = 0
- while (splitIndex < numSplits) {
- if (splits(featureIndex)(splitIndex).categories.contains(featureValue)) {
- agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label,
- instanceWeight)
- }
- splitIndex += 1
- }
- } else {
- // Ordered feature
- val binIndex = treePoint.binnedFeatures(featureIndex)
- agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight)
- }
- featureIndexIdx += 1
- }
- }
-
- /**
- * Helper for binSeqOp, for regression and for classification with only ordered features.
- *
- * For each feature, the sufficient statistics of one bin are updated.
- *
- * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
- * each (feature, bin).
- * @param treePoint Data point being aggregated.
- * @param instanceWeight Weight (importance) of instance in dataset.
- */
- private def orderedBinSeqOp(
- agg: DTStatsAggregator,
- treePoint: TreePoint,
- instanceWeight: Double,
- featuresForNode: Option[Array[Int]]): Unit = {
- val label = treePoint.label
-
- // Iterate over features.
- if (featuresForNode.nonEmpty) {
- // Use subsampled features
- var featureIndexIdx = 0
- while (featureIndexIdx < featuresForNode.get.length) {
- val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
- agg.update(featureIndexIdx, binIndex, label, instanceWeight)
- featureIndexIdx += 1
- }
- } else {
- // Use all features
- val numFeatures = agg.metadata.numFeatures
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- val binIndex = treePoint.binnedFeatures(featureIndex)
- agg.update(featureIndex, binIndex, label, instanceWeight)
- featureIndex += 1
- }
- }
- }
-
- /**
- * Given a group of nodes, this finds the best split for each node.
- *
- * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]].
- * @param metadata Learning and dataset metadata.
- * @param topNodes Root node for each tree. Used for matching instances with nodes.
- * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree.
- * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
- * where nodeIndexInfo stores the index in the group and the
- * feature subsets (if using feature subsets).
- * @param splits Possible splits for all features, indexed (numFeatures)(numSplits).
- * @param bins Possible bins for all features, indexed (numFeatures)(numBins).
- * @param nodeQueue Queue of nodes to split, with values (treeIndex, node).
- * Updated with new non-leaf nodes which are created.
- * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where
- * each value in the array is the data point's node Id
- * for a corresponding tree. This is used to prevent the need
- * to pass the entire tree to the executors during
- * the node stat aggregation phase.
- */
- private[tree] def findBestSplits(
- input: RDD[BaggedPoint[TreePoint]],
- metadata: DecisionTreeMetadata,
- topNodes: Array[Node],
- nodesForGroup: Map[Int, Array[Node]],
- treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
- splits: Array[Array[Split]],
- bins: Array[Array[Bin]],
- nodeQueue: mutable.Queue[(Int, Node)],
- timer: TimeTracker = new TimeTracker,
- nodeIdCache: Option[NodeIdCache] = None): Unit = {
-
- /*
- * The high-level descriptions of the best split optimizations are noted here.
- *
- * *Group-wise training*
- * We perform bin calculations for groups of nodes to reduce the number of
- * passes over the data. Each iteration requires more computation and storage,
- * but saves several iterations over the data.
- *
- * *Bin-wise computation*
- * We use a bin-wise best split computation strategy instead of a straightforward best split
- * computation strategy. Instead of analyzing each sample for contribution to the left/right
- * child node impurity of every split, we first categorize each feature of a sample into a
- * bin. We exploit this structure to calculate aggregates for bins and then use these aggregates
- * to calculate information gain for each split.
- *
- * *Aggregation over partitions*
- * Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know
- * the number of splits in advance. Thus, we store the aggregates (at the appropriate
- * indices) in a single array for all bins and rely upon the RDD aggregate method to
- * drastically reduce the communication overhead.
- */
-
- // numNodes: Number of nodes in this group
- val numNodes = nodesForGroup.values.map(_.length).sum
- logDebug("numNodes = " + numNodes)
- logDebug("numFeatures = " + metadata.numFeatures)
- logDebug("numClasses = " + metadata.numClasses)
- logDebug("isMulticlass = " + metadata.isMulticlass)
- logDebug("isMulticlassWithCategoricalFeatures = " +
- metadata.isMulticlassWithCategoricalFeatures)
- logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString)
-
- /**
- * Performs a sequential aggregation over a partition for a particular tree and node.
- *
- * For each feature, the aggregate sufficient statistics are updated for the relevant
- * bins.
- *
- * @param treeIndex Index of the tree that we want to perform aggregation for.
- * @param nodeInfo The node info for the tree node.
- * @param agg Array storing aggregate calculation, with a set of sufficient statistics
- * for each (node, feature, bin).
- * @param baggedPoint Data point being aggregated.
- */
- def nodeBinSeqOp(
- treeIndex: Int,
- nodeInfo: RandomForest.NodeIndexInfo,
- agg: Array[DTStatsAggregator],
- baggedPoint: BaggedPoint[TreePoint]): Unit = {
- if (nodeInfo != null) {
- val aggNodeIndex = nodeInfo.nodeIndexInGroup
- val featuresForNode = nodeInfo.featureSubset
- val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
- if (metadata.unorderedFeatures.isEmpty) {
- orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
- } else {
- mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
- metadata.unorderedFeatures, instanceWeight, featuresForNode)
- }
- agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight)
- }
- }
-
- /**
- * Performs a sequential aggregation over a partition.
- *
- * Each data point contributes to one node. For each feature,
- * the aggregate sufficient statistics are updated for the relevant bins.
- *
- * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
- * each (node, feature, bin).
- * @param baggedPoint Data point being aggregated.
- * @return Array of decision tree statistics.
- */
- def binSeqOp(
- agg: Array[DTStatsAggregator],
- baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
- treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
- val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,
- bins, metadata.unorderedFeatures)
- nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
- }
-
- agg
- }
-
- /**
- * Do the same thing as binSeqOp, but with nodeIdCache.
- */
- def binSeqOpWithNodeIdCache(
- agg: Array[DTStatsAggregator],
- dataPoint: (BaggedPoint[TreePoint], Array[Int])): Array[DTStatsAggregator] = {
- treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
- val baggedPoint = dataPoint._1
- val nodeIdCache = dataPoint._2
- val nodeIndex = nodeIdCache(treeIndex)
- nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
- }
-
- agg
- }
-
- /**
- * Get node index in group --> features indices map,
- * which is a short cut to find feature indices for a node given node index in group
- *
- * @param treeToNodeToIndexInfo
- * @return
- */
- def getNodeToFeatures(treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]])
- : Option[Map[Int, Array[Int]]] = if (!metadata.subsamplingFeatures) {
- None
- } else {
- val mutableNodeToFeatures = new mutable.HashMap[Int, Array[Int]]()
- treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo =>
- nodeIdToNodeInfo.values.foreach { nodeIndexInfo =>
- assert(nodeIndexInfo.featureSubset.isDefined)
- mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get
- }
- }
- Some(mutableNodeToFeatures.toMap)
- }
-
- // array of nodes to train indexed by node index in group
- val nodes = new Array[Node](numNodes)
- nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
- nodesForTree.foreach { node =>
- nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node
- }
- }
-
- // Calculate best splits for all nodes in the group
- timer.start("chooseSplits")
-
- // In each partition, iterate all instances and compute aggregate stats for each node,
- // yield an (nodeIndex, nodeAggregateStats) pair for each node.
- // After a `reduceByKey` operation,
- // stats of a node will be shuffled to a particular partition and be combined together,
- // then best splits for nodes are found there.
- // Finally, only best Splits for nodes are collected to driver to construct decision tree.
- val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
- val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
-
- val partitionAggregates: RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
- input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
- // Construct a nodeStatsAggregators array to hold node aggregate stats,
- // each node will have a nodeStatsAggregator
- val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
- val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
- Some(nodeToFeatures(nodeIndex))
- }
- new DTStatsAggregator(metadata, featuresForNode)
- }
-
- // iterator all instances in current partition and update aggregate stats
- points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))
-
- // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
- // which can be combined with other partition using `reduceByKey`
- nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
- }
- } else {
- input.mapPartitions { points =>
- // Construct a nodeStatsAggregators array to hold node aggregate stats,
- // each node will have a nodeStatsAggregator
- val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
- val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
- Some(nodeToFeatures(nodeIndex))
- }
- new DTStatsAggregator(metadata, featuresForNode)
- }
-
- // iterator all instances in current partition and update aggregate stats
- points.foreach(binSeqOp(nodeStatsAggregators, _))
-
- // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
- // which can be combined with other partition using `reduceByKey`
- nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
- }
- }
-
- val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b))
- .map { case (nodeIndex, aggStats) =>
- val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures =>
- nodeToFeatures(nodeIndex)
- }
-
- // find best split for each node
- val (split: Split, stats: InformationGainStats, predict: Predict) =
- binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
- (nodeIndex, (split, stats, predict))
- }.collectAsMap()
-
- timer.stop("chooseSplits")
-
- val nodeIdUpdaters = if (nodeIdCache.nonEmpty) {
- Array.fill[mutable.Map[Int, NodeIndexUpdater]](
- metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]())
- } else {
- null
- }
-
- // Iterate over all nodes in this group.
- nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
- nodesForTree.foreach { node =>
- val nodeIndex = node.id
- val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
- val aggNodeIndex = nodeInfo.nodeIndexInGroup
- val (split: Split, stats: InformationGainStats, predict: Predict) =
- nodeToBestSplits(aggNodeIndex)
- logDebug("best split = " + split)
-
- // Extract info for this node. Create children if not leaf.
- val isLeaf = (stats.gain <= 0) || (Node.indexToLevel(nodeIndex) == metadata.maxDepth)
- assert(node.id == nodeIndex)
- node.predict = predict
- node.isLeaf = isLeaf
- node.stats = Some(stats)
- node.impurity = stats.impurity
- logDebug("Node = " + node)
-
- if (!isLeaf) {
- node.split = Some(split)
- val childIsLeaf = (Node.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
- val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
- val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
- node.leftNode = Some(Node(Node.leftChildIndex(nodeIndex),
- stats.leftPredict, stats.leftImpurity, leftChildIsLeaf))
- node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex),
- stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))
-
- if (nodeIdCache.nonEmpty) {
- val nodeIndexUpdater = NodeIndexUpdater(
- split = split,
- nodeIndex = nodeIndex)
- nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater)
- }
-
- // enqueue left child and right child if they are not leaves
- if (!leftChildIsLeaf) {
- nodeQueue.enqueue((treeIndex, node.leftNode.get))
- }
- if (!rightChildIsLeaf) {
- nodeQueue.enqueue((treeIndex, node.rightNode.get))
- }
-
- logDebug("leftChildIndex = " + node.leftNode.get.id +
- ", impurity = " + stats.leftImpurity)
- logDebug("rightChildIndex = " + node.rightNode.get.id +
- ", impurity = " + stats.rightImpurity)
- }
- }
- }
-
- if (nodeIdCache.nonEmpty) {
- // Update the cache if needed.
- nodeIdCache.get.updateNodeIndices(input, nodeIdUpdaters, bins)
- }
- }
-
- /**
- * Calculate the information gain for a given (feature, split) based upon left/right aggregates.
- *
- * @param leftImpurityCalculator Left node aggregates for this (feature, split).
- * @param rightImpurityCalculator Right node aggregate for this (feature, split).
- * @return Information gain and statistics for split.
- */
- private def calculateGainForSplit(
- leftImpurityCalculator: ImpurityCalculator,
- rightImpurityCalculator: ImpurityCalculator,
- metadata: DecisionTreeMetadata,
- impurity: Double): InformationGainStats = {
- val leftCount = leftImpurityCalculator.count
- val rightCount = rightImpurityCalculator.count
-
- // If left child or right child doesn't satisfy minimum instances per node,
- // then this split is invalid, return invalid information gain stats.
- if ((leftCount < metadata.minInstancesPerNode) ||
- (rightCount < metadata.minInstancesPerNode)) {
- return InformationGainStats.invalidInformationGainStats
- }
-
- val totalCount = leftCount + rightCount
-
- val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
- val rightImpurity = rightImpurityCalculator.calculate()
-
- val leftWeight = leftCount / totalCount.toDouble
- val rightWeight = rightCount / totalCount.toDouble
-
- val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
-
- // if information gain doesn't satisfy minimum information gain,
- // then this split is invalid, return invalid information gain stats.
- if (gain < metadata.minInfoGain) {
- return InformationGainStats.invalidInformationGainStats
- }
-
- // calculate left and right predict
- val leftPredict = calculatePredict(leftImpurityCalculator)
- val rightPredict = calculatePredict(rightImpurityCalculator)
-
- new InformationGainStats(gain, impurity, leftImpurity, rightImpurity,
- leftPredict, rightPredict)
- }
-
- private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = {
- val predict = impurityCalculator.predict
- val prob = impurityCalculator.prob(predict)
- new Predict(predict, prob)
- }
-
- /**
- * Calculate predict value for current node, given stats of any split.
- * Note that this function is called only once for each node.
- *
- * @param leftImpurityCalculator Left node aggregates for a split.
- * @param rightImpurityCalculator Right node aggregates for a split.
- * @return Predict value and impurity for current node.
- */
- private def calculatePredictImpurity(
- leftImpurityCalculator: ImpurityCalculator,
- rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = {
- val parentNodeAgg = leftImpurityCalculator.copy
- parentNodeAgg.add(rightImpurityCalculator)
- val predict = calculatePredict(parentNodeAgg)
- val impurity = parentNodeAgg.calculate()
-
- (predict, impurity)
- }
-
- /**
- * Find the best split for a node.
- *
- * @param binAggregates Bin statistics.
- * @return Tuple for best split: (Split, information gain, prediction at node).
- */
- private[tree] def binsToBestSplit(
- binAggregates: DTStatsAggregator,
- splits: Array[Array[Split]],
- featuresForNode: Option[Array[Int]],
- node: Node): (Split, InformationGainStats, Predict) = {
-
- // calculate predict and impurity if current node is top node
- val level = Node.indexToLevel(node.id)
- var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) {
- None
- } else {
- Some((node.predict, node.impurity))
- }
-
- // For each (feature, split), calculate the gain, and select the best (feature, split).
- val (bestSplit, bestSplitStats) =
- Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
- val featureIndex = if (featuresForNode.nonEmpty) {
- featuresForNode.get.apply(featureIndexIdx)
- } else {
- featureIndexIdx
- }
- val numSplits = binAggregates.metadata.numSplits(featureIndex)
- if (binAggregates.metadata.isContinuous(featureIndex)) {
- // Cumulative sum (scanLeft) of bin statistics.
- // Afterwards, binAggregates for a bin is the sum of aggregates for
- // that bin + all preceding bins.
- val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
- var splitIndex = 0
- while (splitIndex < numSplits) {
- binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
- splitIndex += 1
- }
- // Find best split.
- val (bestFeatureSplitIndex, bestFeatureGainStats) =
- Range(0, numSplits).map { case splitIdx =>
- val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
- val rightChildStats =
- binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
- rightChildStats.subtract(leftChildStats)
- predictWithImpurity = Some(predictWithImpurity.getOrElse(
- calculatePredictImpurity(leftChildStats, rightChildStats)))
- val gainStats = calculateGainForSplit(leftChildStats,
- rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
- (splitIdx, gainStats)
- }.maxBy(_._2.gain)
- (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
- } else if (binAggregates.metadata.isUnordered(featureIndex)) {
- // Unordered categorical feature
- val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
- val (bestFeatureSplitIndex, bestFeatureGainStats) =
- Range(0, numSplits).map { splitIndex =>
- val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
- val rightChildStats = binAggregates.getParentImpurityCalculator()
- .subtract(leftChildStats)
- predictWithImpurity = Some(predictWithImpurity.getOrElse(
- calculatePredictImpurity(leftChildStats, rightChildStats)))
- val gainStats = calculateGainForSplit(leftChildStats,
- rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
- (splitIndex, gainStats)
- }.maxBy(_._2.gain)
- (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
- } else {
- // Ordered categorical feature
- val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
- val numBins = binAggregates.metadata.numBins(featureIndex)
-
- /* Each bin is one category (feature value).
- * The bins are ordered based on centroidForCategories, and this ordering determines which
- * splits are considered. (With K categories, we consider K - 1 possible splits.)
- *
- * centroidForCategories is a list: (category, centroid)
- */
- val centroidForCategories = Range(0, numBins).map { case featureValue =>
- val categoryStats =
- binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
- val centroid = if (categoryStats.count != 0) {
- if (binAggregates.metadata.isMulticlass) {
- // For categorical variables in multiclass classification,
- // the bins are ordered by the impurity of their corresponding labels.
- categoryStats.calculate()
- } else if (binAggregates.metadata.isClassification) {
- // For categorical variables in binary classification,
- // the bins are ordered by the count of class 1.
- categoryStats.stats(1)
- } else {
- // For categorical variables in regression,
- // the bins are ordered by the prediction.
- categoryStats.predict
- }
- } else {
- Double.MaxValue
- }
- (featureValue, centroid)
- }
-
- logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
-
- // bins sorted by centroids
- val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
-
- logDebug("Sorted centroids for categorical variable = " +
- categoriesSortedByCentroid.mkString(","))
-
- // Cumulative sum (scanLeft) of bin statistics.
- // Afterwards, binAggregates for a bin is the sum of aggregates for
- // that bin + all preceding bins.
- var splitIndex = 0
- while (splitIndex < numSplits) {
- val currentCategory = categoriesSortedByCentroid(splitIndex)._1
- val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
- binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
- splitIndex += 1
- }
- // lastCategory = index of bin with total aggregates for this (node, feature)
- val lastCategory = categoriesSortedByCentroid.last._1
- // Find best split.
- val (bestFeatureSplitIndex, bestFeatureGainStats) =
- Range(0, numSplits).map { splitIndex =>
- val featureValue = categoriesSortedByCentroid(splitIndex)._1
- val leftChildStats =
- binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
- val rightChildStats =
- binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
- rightChildStats.subtract(leftChildStats)
- predictWithImpurity = Some(predictWithImpurity.getOrElse(
- calculatePredictImpurity(leftChildStats, rightChildStats)))
- val gainStats = calculateGainForSplit(leftChildStats,
- rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
- (splitIndex, gainStats)
- }.maxBy(_._2.gain)
- val categoriesForSplit =
- categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
- val bestFeatureSplit =
- new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
- (bestFeatureSplit, bestFeatureGainStats)
- }
- }.maxBy(_._2.gain)
-
- (bestSplit, bestSplitStats, predictWithImpurity.get._1)
- }
-
- /**
- * Returns splits and bins for decision tree calculation.
- * Continuous and categorical features are handled differently.
- *
- * Continuous features:
- * For each feature, there are numBins - 1 possible splits representing the possible binary
- * decisions at each node in the tree.
- * This finds locations (feature values) for splits using a subsample of the data.
- *
- * Categorical features:
- * For each feature, there is 1 bin per split.
- * Splits and bins are handled in 2 ways:
- * (a) "unordered features"
- * For multiclass classification with a low-arity feature
- * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
- * the feature is split based on subsets of categories.
- * (b) "ordered features"
- * For regression and binary classification,
- * and for multiclass classification with a high-arity feature,
- * there is one bin per category.
- *
- * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * @param metadata Learning and dataset metadata.
- * @return A tuple of (splits, bins).
- * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]]
- * of size (numFeatures, numSplits).
- * Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]]
- * of size (numFeatures, numBins).
- */
- protected[tree] def findSplitsBins(
- input: RDD[LabeledPoint],
- metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = {
-
- logDebug("isMulticlass = " + metadata.isMulticlass)
-
- val numFeatures = metadata.numFeatures
-
- // Sample the input only if there are continuous features.
- val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous)
- val sampledInput = if (continuousFeatures.nonEmpty) {
- // Calculate the number of samples for approximate quantile calculation.
- val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
- val fraction = if (requiredSamples < metadata.numExamples) {
- requiredSamples.toDouble / metadata.numExamples
- } else {
- 1.0
- }
- logDebug("fraction of data used for calculating quantiles = " + fraction)
- input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt())
- } else {
- input.sparkContext.emptyRDD[LabeledPoint]
- }
-
- metadata.quantileStrategy match {
- case Sort =>
- findSplitsBinsBySorting(sampledInput, metadata, continuousFeatures)
- case MinMax =>
- throw new UnsupportedOperationException("minmax not supported yet.")
- case ApproxHist =>
- throw new UnsupportedOperationException("approximate histogram not supported yet.")
- }
- }
-
- private def findSplitsBinsBySorting(
- input: RDD[LabeledPoint],
- metadata: DecisionTreeMetadata,
- continuousFeatures: IndexedSeq[Int]): (Array[Array[Split]], Array[Array[Bin]]) = {
- def findSplits(
- featureIndex: Int,
- featureSamples: Iterable[Double]): (Int, (Array[Split], Array[Bin])) = {
- val splits = {
- val featureSplits = findSplitsForContinuousFeature(
- featureSamples,
- metadata,
- featureIndex)
- logDebug(s"featureIndex = $featureIndex, numSplits = ${featureSplits.length}")
-
- featureSplits.map(threshold => new Split(featureIndex, threshold, Continuous, Nil))
- }
-
- val bins = {
- val lowSplit = new DummyLowSplit(featureIndex, Continuous)
- val highSplit = new DummyHighSplit(featureIndex, Continuous)
-
- // tack the dummy splits on either side of the computed splits
- val allSplits = lowSplit +: splits.toSeq :+ highSplit
-
- // slide across the split points pairwise to allocate the bins
- allSplits.sliding(2).map {
- case Seq(left, right) => new Bin(left, right, Continuous, Double.MinValue)
- }.toArray
- }
-
- (featureIndex, (splits, bins))
- }
-
- val continuousSplits = {
- // reduce the parallelism for split computations when there are less
- // continuous features than input partitions. this prevents tasks from
- // being spun up that will definitely do no work.
- val numPartitions = math.min(continuousFeatures.length, input.partitions.length)
-
- input
- .flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx))))
- .groupByKey(numPartitions)
- .map { case (k, v) => findSplits(k, v) }
- .collectAsMap()
- }
-
- val numFeatures = metadata.numFeatures
- val (splits, bins) = Range(0, numFeatures).unzip {
- case i if metadata.isContinuous(i) =>
- val (split, bin) = continuousSplits(i)
- metadata.setNumSplits(i, split.length)
- (split, bin)
-
- case i if metadata.isCategorical(i) && metadata.isUnordered(i) =>
- // Unordered features
- // 2^(maxFeatureValue - 1) - 1 combinations
- val featureArity = metadata.featureArity(i)
- val split = Range(0, metadata.numSplits(i)).map { splitIndex =>
- val categories = extractMultiClassCategories(splitIndex + 1, featureArity)
- new Split(i, Double.MinValue, Categorical, categories)
- }
-
- // For unordered categorical features, there is no need to construct the bins.
- // since there is a one-to-one correspondence between the splits and the bins.
- (split.toArray, Array.empty[Bin])
-
- case i if metadata.isCategorical(i) =>
- // Ordered features
- // Bins correspond to feature values, so we do not need to compute splits or bins
- // beforehand. Splits are constructed as needed during training.
- (Array.empty[Split], Array.empty[Bin])
- }
-
- (splits.toArray, bins.toArray)
- }
-
- /**
- * Nested method to extract the list of eligible categories given an index. It extracts the
- * positions of ones in the binary representation of the input. If the binary
- * representation of a number is 01101 (13), the output list should be (3.0, 2.0,
- * 0.0). The maxFeatureValue depicts the number of rightmost digits that will be tested for ones.
- */
- private[tree] def extractMultiClassCategories(
- input: Int,
- maxFeatureValue: Int): List[Double] = {
- var categories = List[Double]()
- var j = 0
- var bitShiftedInput = input
- while (j < maxFeatureValue) {
- if (bitShiftedInput % 2 != 0) {
- // updating the list of categories.
- categories = j.toDouble :: categories
- }
- // Right shift by one
- bitShiftedInput = bitShiftedInput >> 1
- j += 1
- }
- categories
- }
-
- /**
- * Find splits for a continuous feature
- * NOTE: Returned number of splits is set based on `featureSamples` and
- * could be different from the specified `numSplits`.
- * The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
- *
- * @param featureSamples Feature values of each sample.
- * @param metadata Decision tree metadata.
- * NOTE: `metadata.numbins` will be changed accordingly
- * if there are not enough splits to be found.
- * @param featureIndex Feature index to find splits.
- * @return Array of splits.
- */
- private[tree] def findSplitsForContinuousFeature(
- featureSamples: Iterable[Double],
- metadata: DecisionTreeMetadata,
- featureIndex: Int): Array[Double] = {
- require(metadata.isContinuous(featureIndex),
- "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
-
- val splits = {
- val numSplits = metadata.numSplits(featureIndex)
-
- // get count for each distinct value
- val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
- case ((m, cnt), x) =>
- (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
- }
- // sort distinct values
- val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
-
- // if possible splits is not enough or just enough, just return all possible splits
- val possibleSplits = valueCounts.length
- if (possibleSplits <= numSplits) {
- valueCounts.map(_._1)
- } else {
- // stride between splits
- val stride: Double = numSamples.toDouble / (numSplits + 1)
- logDebug("stride = " + stride)
-
- // iterate `valueCount` to find splits
- val splitsBuilder = Array.newBuilder[Double]
- var index = 1
- // currentCount: sum of counts of values that have been visited
- var currentCount = valueCounts(0)._2
- // targetCount: target value for `currentCount`.
- // If `currentCount` is closest value to `targetCount`,
- // then current value is a split threshold.
- // After finding a split threshold, `targetCount` is added by stride.
- var targetCount = stride
- while (index < valueCounts.length) {
- val previousCount = currentCount
- currentCount += valueCounts(index)._2
- val previousGap = math.abs(previousCount - targetCount)
- val currentGap = math.abs(currentCount - targetCount)
- // If adding count of current value to currentCount
- // makes the gap between currentCount and targetCount smaller,
- // previous value is a split threshold.
- if (previousGap < currentGap) {
- splitsBuilder += valueCounts(index - 1)._1
- targetCount += stride
- }
- index += 1
- }
-
- splitsBuilder.result()
- }
- }
-
- // TODO: Do not fail; just ignore the useless feature.
- assert(splits.length > 0,
- s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
- " Please remove this feature and then try again.")
-
- // the split metadata must be updated on the driver
-
- splits
- }
}
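The removed calculateGainForSplit above (its counterpart now lives in the spark.ml implementation) computes a split's information gain as the parent impurity minus the count-weighted child impurities, returning invalidInformationGainStats when minInstancesPerNode or minInfoGain is violated. Below is a standalone sketch of just that arithmetic, with plain numbers standing in for Spark's ImpurityCalculator.

```scala
// gain = parentImpurity - (nLeft / n) * leftImpurity - (nRight / n) * rightImpurity
object InfoGainSketch {
  def infoGain(
      parentImpurity: Double,
      leftImpurity: Double, leftCount: Long,
      rightImpurity: Double, rightCount: Long): Double = {
    val total = (leftCount + rightCount).toDouble
    parentImpurity -
      (leftCount / total) * leftImpurity -
      (rightCount / total) * rightImpurity
  }

  // Example: Gini impurity 0.5 at the parent; a split into two pure children
  // of equal size yields the maximum possible gain of 0.5.
  def demo(): Unit = assert(infoGain(0.5, 0.0, 50L, 0.0, 50L) == 0.5)
}
```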
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index ec4c034169..1841fa4a95 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -17,26 +17,22 @@
package org.apache.spark.mllib.tree
-import java.io.IOException
-
-import scala.collection.mutable
import scala.collection.JavaConverters._
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
+import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, RandomForestParams => NewRFParams}
+import org.apache.spark.ml.tree.impl.{RandomForest => NewRandomForest}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, NodeIdCache,
- TimeTracker, TreePoint}
import org.apache.spark.mllib.tree.impurity.Impurities
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
-import org.apache.spark.util.random.SamplingUtils
+
/**
* A class that implements a [[http://en.wikipedia.org/wiki/Random_forest Random Forest]]
@@ -72,47 +68,6 @@ private class RandomForest (
private val seed: Int)
extends Serializable with Logging {
- /*
- ALGORITHM
- This is a sketch of the algorithm to help new developers.
-
- The algorithm partitions data by instances (rows).
- On each iteration, the algorithm splits a set of nodes. In order to choose the best split
- for a given node, sufficient statistics are collected from the distributed data.
- For each node, the statistics are collected to some worker node, and that worker selects
- the best split.
-
- This setup requires discretization of continuous features. This binning is done in the
- findSplitsBins() method during initialization, after which each continuous feature becomes
- an ordered discretized feature with at most maxBins possible values.
-
- The main loop in the algorithm operates on a queue of nodes (nodeQueue). These nodes
- lie at the periphery of the tree being trained. If multiple trees are being trained at once,
- then this queue contains nodes from all of them. Each iteration works roughly as follows:
- On the master node:
- - Some number of nodes are pulled off of the queue (based on the amount of memory
- required for their sufficient statistics).
- - For random forests, if featureSubsetStrategy is not "all," then a subset of candidate
- features is chosen for each node. See method selectNodesToSplit().
- On worker nodes, via method findBestSplits():
- - The worker makes one pass over its subset of instances.
- - For each (tree, node, feature, split) tuple, the worker collects statistics about
- splitting. Note that the set of (tree, node) pairs is limited to the nodes selected
- from the queue for this iteration. The set of features considered can also be limited
- based on featureSubsetStrategy.
- - For each node, the statistics for that node are aggregated to a particular worker
- via reduceByKey(). The designated worker chooses the best (feature, split) pair,
- or chooses to stop splitting if the stopping criteria are met.
- On the master node:
- - The master collects all decisions about splitting nodes and updates the model.
- - The updated model is passed to the workers on the next iteration.
- This process continues until the node queue is empty.
-
- Most of the methods in this implementation support the statistics aggregation, which is
- the heaviest part of the computation. In general, this implementation is bound by either
- the cost of statistics computation on workers or by communicating the sufficient statistics.
- */
-
strategy.assertValid()
require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.")
require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy),
@@ -126,135 +81,9 @@ private class RandomForest (
* @return RandomForestModel that can be used for prediction.
*/
def run(input: RDD[LabeledPoint]): RandomForestModel = {
-
- val timer = new TimeTracker()
-
- timer.start("total")
-
- timer.start("init")
-
- val retaggedInput = input.retag(classOf[LabeledPoint])
- val metadata =
- DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
- logDebug("algo = " + strategy.algo)
- logDebug("numTrees = " + numTrees)
- logDebug("seed = " + seed)
- logDebug("maxBins = " + metadata.maxBins)
- logDebug("featureSubsetStrategy = " + featureSubsetStrategy)
- logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode)
- logDebug("subsamplingRate = " + strategy.subsamplingRate)
-
- // Find the splits and the corresponding bins (interval between the splits) using a sample
- // of the input data.
- timer.start("findSplitsBins")
- val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata)
- timer.stop("findSplitsBins")
- logDebug("numBins: feature: number of bins")
- logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
- s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
- }.mkString("\n"))
-
- // Bin feature values (TreePoint representation).
- // Cache input RDD for speedup during multiple passes.
- val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
-
- val withReplacement = if (numTrees > 1) true else false
-
- val baggedInput
- = BaggedPoint.convertToBaggedRDD(treeInput,
- strategy.subsamplingRate, numTrees,
- withReplacement, seed).persist(StorageLevel.MEMORY_AND_DISK)
-
- // depth of the decision tree
- val maxDepth = strategy.maxDepth
- require(maxDepth <= 30,
- s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")
-
- // Max memory usage for aggregates
- // TODO: Calculate memory usage more precisely.
- val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
- logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
- val maxMemoryPerNode = {
- val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
- // Find numFeaturesPerNode largest bins to get an upper bound on memory usage.
- Some(metadata.numBins.zipWithIndex.sortBy(- _._1)
- .take(metadata.numFeaturesPerNode).map(_._2))
- } else {
- None
- }
- RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
- }
- require(maxMemoryPerNode <= maxMemoryUsage,
- s"RandomForest/DecisionTree given maxMemoryInMB = ${strategy.maxMemoryInMB}," +
- " which is too small for the given features." +
- s" Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}")
-
- timer.stop("init")
-
- /*
- * The main idea here is to perform group-wise training of the decision tree nodes, thus
- * reducing the number of passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
- * Each data sample is handled by a particular node (or it reaches a leaf and is not used
- * in lower levels).
- */
-
- // Create an RDD of node Id cache.
- // At first, all the rows belong to the root nodes (node Id == 1).
- val nodeIdCache = if (strategy.useNodeIdCache) {
- Some(NodeIdCache.init(
- data = baggedInput,
- numTrees = numTrees,
- checkpointInterval = strategy.checkpointInterval,
- initVal = 1))
- } else {
- None
- }
-
- // FIFO queue of nodes to train: (treeIndex, node)
- val nodeQueue = new mutable.Queue[(Int, Node)]()
-
- val rng = new scala.util.Random()
- rng.setSeed(seed)
-
- // Allocate and queue root nodes.
- val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1))
- Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
-
- while (nodeQueue.nonEmpty) {
- // Collect some nodes to split, and choose features for each node (if subsampling).
- // Each group of nodes may come from one or multiple trees, and at multiple levels.
- val (nodesForGroup, treeToNodeToIndexInfo) =
- RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
- // Sanity check (should never occur):
- assert(nodesForGroup.size > 0,
- s"RandomForest selected empty nodesForGroup. Error for unknown reason.")
-
- // Choose node splits, and enqueue new nodes as needed.
- timer.start("findBestSplits")
- DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
- treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
- timer.stop("findBestSplits")
- }
-
- baggedInput.unpersist()
-
- timer.stop("total")
-
- logInfo("Internal timing for DecisionTree:")
- logInfo(s"$timer")
-
- // Delete any remaining checkpoints used for node Id cache.
- if (nodeIdCache.nonEmpty) {
- try {
- nodeIdCache.get.deleteAllCheckpoints()
- } catch {
- case e: IOException =>
- logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}")
- }
- }
-
- val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
- new RandomForestModel(strategy.algo, trees)
+ val trees: Array[NewDTModel] =
+ NewRandomForest.run(input, strategy, numTrees, featureSubsetStrategy, seed.toLong)
+ new RandomForestModel(strategy.algo, trees.map(_.toOld))
}
}
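
The public spark.mllib entry points keep their signatures; only the body of run changes. As a rough usage sketch (not part of this patch; the data path is hypothetical and an existing SparkContext sc is assumed), training still goes through the familiar mllib API and now executes the spark.ml implementation underneath:

    import org.apache.spark.mllib.tree.RandomForest
    import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
    import org.apache.spark.mllib.tree.impurity.Gini
    import org.apache.spark.mllib.util.MLUtils

    // Load a LIBSVM-format dataset (path is illustrative only).
    val data = MLUtils.loadLibSVMFile(sc, "data/sample_libsvm_data.txt")
    val strategy = new Strategy(Algo.Classification, Gini, maxDepth = 4, numClasses = 2, maxBins = 32)

    // Returns the same mllib RandomForestModel as before; internally the trees are trained by
    // ml.tree.impl.RandomForest and converted back with toOld, as shown in run() above.
    val model = RandomForest.trainClassifier(data, strategy, numTrees = 10,
      featureSubsetStrategy = "auto", seed = 42)
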
@@ -441,86 +270,5 @@ object RandomForest extends Serializable with Logging {
* List of supported feature subset sampling strategies.
*/
@Since("1.2.0")
- val supportedFeatureSubsetStrategies: Array[String] =
- Array("auto", "all", "sqrt", "log2", "onethird")
-
- private[tree] class NodeIndexInfo(
- val nodeIndexInGroup: Int,
- val featureSubset: Option[Array[Int]]) extends Serializable
-
- /**
- * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration.
- * This tracks the memory usage for aggregates and stops adding nodes when too much memory
- * will be needed; this allows an adaptive number of nodes since different nodes may require
- * different amounts of memory (if featureSubsetStrategy is not "all").
- *
- * @param nodeQueue Queue of nodes to split.
- * @param maxMemoryUsage Bound on size of aggregate statistics.
- * @return (nodesForGroup, treeToNodeToIndexInfo).
- * nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
- *
- * treeToNodeToIndexInfo holds the indices of selected features for each node:
- * treeIndex --> (global) node index --> (node index in group, feature indices).
- * The (global) node index is the index in the tree; the node index in group is the
- * index in [0, numNodesInGroup) of the node in this group.
- * The feature indices are None if not subsampling features.
- */
- private[tree] def selectNodesToSplit(
- nodeQueue: mutable.Queue[(Int, Node)],
- maxMemoryUsage: Long,
- metadata: DecisionTreeMetadata,
- rng: scala.util.Random): (Map[Int, Array[Node]], Map[Int, Map[Int, NodeIndexInfo]]) = {
- // Collect some nodes to split:
- // nodesForGroup(treeIndex) = nodes to split
- val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[Node]]()
- val mutableTreeToNodeToIndexInfo =
- new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]()
- var memUsage: Long = 0L
- var numNodesInGroup = 0
- while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) {
- val (treeIndex, node) = nodeQueue.head
- // Choose subset of features for node (if subsampling).
- val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
- Some(SamplingUtils.reservoirSampleAndCount(Range(0,
- metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong)._1)
- } else {
- None
- }
- // Check if enough memory remains to add this node to the group.
- val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
- if (memUsage + nodeMemUsage <= maxMemoryUsage) {
- nodeQueue.dequeue()
- mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[Node]()) += node
- mutableTreeToNodeToIndexInfo
- .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id)
- = new NodeIndexInfo(numNodesInGroup, featureSubset)
- }
- numNodesInGroup += 1
- memUsage += nodeMemUsage
- }
- // Convert mutable maps to immutable ones.
- val nodesForGroup: Map[Int, Array[Node]] = mutableNodesForGroup.mapValues(_.toArray).toMap
- val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap
- (nodesForGroup, treeToNodeToIndexInfo)
- }
-
- /**
- * Get the number of values to be stored for this node in the bin aggregates.
- * @param featureSubset Indices of features which may be split at this node.
- * If None, then use all features.
- */
- private[tree] def aggregateSizeForNode(
- metadata: DecisionTreeMetadata,
- featureSubset: Option[Array[Int]]): Long = {
- val totalBins = if (featureSubset.nonEmpty) {
- featureSubset.get.map(featureIndex => metadata.numBins(featureIndex).toLong).sum
- } else {
- metadata.numBins.map(_.toLong).sum
- }
- if (metadata.isClassification) {
- metadata.numClasses * totalBins
- } else {
- 3 * totalBins
- }
- }
+ val supportedFeatureSubsetStrategies: Array[String] = NewRFParams.supportedFeatureSubsetStrategies
}
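
The removed aggregateSizeForNode and selectNodesToSplit now live in ml.tree.impl.RandomForest, but the sizing rule is worth keeping in mind when choosing maxMemoryInMB. A back-of-the-envelope sketch with illustrative numbers only:

    // For classification, each node's bin aggregates hold one Double per (class, bin),
    // so the estimated size is numClasses * totalBins * 8 bytes (regression uses 3 * totalBins).
    val numClasses = 3
    val numBinsPerFeature = Array(32, 32, 100, 10)      // hypothetical metadata.numBins
    val totalBins = numBinsPerFeature.map(_.toLong).sum // 174
    val bytesPerNode = numClasses * totalBins * 8L      // 4176 bytes
    println(s"~$bytesPerNode bytes of aggregates per node")
    // selectNodesToSplit keeps adding queued nodes to a group until the running sum of these
    // estimates reaches maxMemoryInMB * 1024 * 1024; per the Strategy doc change below, a group
    // always admits at least one node even if that node alone exceeds the budget.
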
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 8a0907564e..0214db55c1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -56,7 +56,8 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
* If a split has less information gain than minInfoGain,
* this split will not be considered as a valid split.
* @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
- * 256 MB.
+ * 256 MB. If too small, then 1 node will be split per iteration, and
+ * its aggregates may exceed this size.
* @param subsamplingRate Fraction of the training data used for learning decision tree.
* @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will
* maintain a separate RDD of node Id cache for each row.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index db1e27bf70..f3dbfd96e1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -68,16 +68,6 @@ class InformationGainStats(
}
}
-private[spark] object InformationGainStats {
- /**
- * An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to
- * denote that the current split does not satisfy the minimum info gain or
- * the minimum number of instances per node.
- */
- val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0,
- new Predict(0.0, 0.0), new Predict(0.0, 0.0))
-}
-
/**
* Impurity statistics for each split
* @param gain information gain value
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index 9d922291a6..361366fde7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -17,11 +17,17 @@
package org.apache.spark.ml.tree.impl
+import scala.collection.mutable
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
-import org.apache.spark.ml.tree.{ContinuousSplit, DecisionTreeModel, LeafNode, Node}
+import org.apache.spark.ml.tree._
import org.apache.spark.mllib.linalg.{Vector, Vectors}
-import org.apache.spark.mllib.tree.impurity.GiniCalculator
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata}
+import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.collection.OpenHashMap
@@ -33,6 +39,414 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
import RandomForestSuite.mapToVec
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests for split calculation
+ /////////////////////////////////////////////////////////////////////////////
+
+ test("Binary classification with continuous features: split calculation") {
+ val arr = OldDTSuite.generateOrderedLabeledPointsWithLabel1()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2, 100)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ val splits = RandomForest.findSplits(rdd, metadata, seed = 42)
+ assert(splits.length === 2)
+ assert(splits(0).length === 99)
+ }
+
+ test("Binary classification with binary (ordered) categorical features: split calculation") {
+ val arr = OldDTSuite.generateCategoricalDataPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2,
+ maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2))
+
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val splits = RandomForest.findSplits(rdd, metadata, seed = 42)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ assert(!metadata.isUnordered(featureIndex = 1))
+ assert(splits.length === 2)
+ // no splits pre-computed for ordered categorical features
+ assert(splits(0).length === 0)
+ }
+
+ test("Binary classification with 3-ary (ordered) categorical features," +
+ " with no samples for one category: split calculation") {
+ val arr = OldDTSuite.generateCategoricalDataPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2,
+ maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
+
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ assert(!metadata.isUnordered(featureIndex = 1))
+ val splits = RandomForest.findSplits(rdd, metadata, seed = 42)
+ assert(splits.length === 2)
+ // no splits pre-computed for ordered categorical features
+ assert(splits(0).length === 0)
+ }
+
+ test("find splits for a continuous feature") {
+ // find splits for normal case
+ {
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ Map(), Set(),
+ Array(6), Gini, QuantileStrategy.Sort,
+ 0, 0, 0.0, 0, 0
+ )
+ val featureSamples = Array.fill(200000)(math.random)
+ val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+ assert(splits.length === 5)
+ assert(fakeMetadata.numSplits(0) === 5)
+ assert(fakeMetadata.numBins(0) === 6)
+ // check returned splits are distinct
+ assert(splits.distinct.length === splits.length)
+ }
+
+ // find splits should not return identical splits
+ // when there are not enough split candidates, reduce the number of splits in metadata
+ {
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ Map(), Set(),
+ Array(5), Gini, QuantileStrategy.Sort,
+ 0, 0, 0.0, 0, 0
+ )
+ val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
+ val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+ assert(splits.length === 3)
+ // check returned splits are distinct
+ assert(splits.distinct.length === splits.length)
+ }
+
+ // find splits when most samples close to the minimum
+ {
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ Map(), Set(),
+ Array(3), Gini, QuantileStrategy.Sort,
+ 0, 0, 0.0, 0, 0
+ )
+ val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
+ val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+ assert(splits.length === 2)
+ assert(splits(0) === 2.0)
+ assert(splits(1) === 3.0)
+ }
+
+ // find splits when most samples close to the maximum
+ {
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ Map(), Set(),
+ Array(3), Gini, QuantileStrategy.Sort,
+ 0, 0, 0.0, 0, 0
+ )
+ val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
+ val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+ assert(splits.length === 1)
+ assert(splits(0) === 1.0)
+ }
+ }
+
+ test("Multiclass classification with unordered categorical features: split calculations") {
+ val arr = OldDTSuite.generateCategoricalDataPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new OldStrategy(
+ OldAlgo.Classification,
+ Gini,
+ maxDepth = 2,
+ numClasses = 100,
+ maxBins = 100,
+ categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
+
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(metadata.isUnordered(featureIndex = 0))
+ assert(metadata.isUnordered(featureIndex = 1))
+ val splits = RandomForest.findSplits(rdd, metadata, seed = 42)
+ assert(splits.length === 2)
+ assert(splits(0).length === 3)
+ assert(metadata.numSplits(0) === 3)
+ assert(metadata.numBins(0) === 3)
+ assert(metadata.numSplits(1) === 3)
+ assert(metadata.numBins(1) === 3)
+
+ // Expecting 2^2 - 1 = 3 splits per feature
+ def checkCategoricalSplit(s: Split, featureIndex: Int, leftCategories: Array[Double]): Unit = {
+ assert(s.featureIndex === featureIndex)
+ assert(s.isInstanceOf[CategoricalSplit])
+ val s0 = s.asInstanceOf[CategoricalSplit]
+ assert(s0.leftCategories === leftCategories)
+ assert(s0.numCategories === 3) // for this unit test
+ }
+ // Feature 0
+ checkCategoricalSplit(splits(0)(0), 0, Array(0.0))
+ checkCategoricalSplit(splits(0)(1), 0, Array(1.0))
+ checkCategoricalSplit(splits(0)(2), 0, Array(0.0, 1.0))
+ // Feature 1
+ checkCategoricalSplit(splits(1)(0), 1, Array(0.0))
+ checkCategoricalSplit(splits(1)(1), 1, Array(1.0))
+ checkCategoricalSplit(splits(1)(2), 1, Array(0.0, 1.0))
+ }
+
+ test("Multiclass classification with ordered categorical features: split calculations") {
+ val arr = OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
+ assert(arr.length === 3000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 100,
+ maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
+ // 2^(10-1) - 1 > 100, so categorical features will be ordered
+
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ assert(!metadata.isUnordered(featureIndex = 1))
+ val splits = RandomForest.findSplits(rdd, metadata, seed = 42)
+ assert(splits.length === 2)
+ // no splits pre-computed for ordered categorical features
+ assert(splits(0).length === 0)
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests of other algorithm internals
+ /////////////////////////////////////////////////////////////////////////////
+
+ test("extract categories from a number for multiclass classification") {
+ val l = RandomForest.extractMultiClassCategories(13, 10)
+ assert(l.length === 3)
+ assert(Seq(3.0, 2.0, 0.0) === l)
+ }
+
+ test("Avoid aggregation on the last level") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
+ val input = sc.parallelize(arr)
+
+ val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1,
+ numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
+ val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+ val splits = RandomForest.findSplits(input, metadata, seed = 42)
+
+ val treeInput = TreePoint.convertToTreeRDD(input, splits, metadata)
+ val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, withReplacement = false)
+
+ val topNode = LearningNode.emptyNode(nodeIndex = 1)
+ assert(topNode.isLeaf === false)
+ assert(topNode.stats === null)
+
+ val nodesForGroup = Map((0, Array(topNode)))
+ val treeToNodeToIndexInfo = Map((0, Map(
+ (topNode.id, new RandomForest.NodeIndexInfo(0, None))
+ )))
+ val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
+ RandomForest.findBestSplits(baggedInput, metadata, Array(topNode),
+ nodesForGroup, treeToNodeToIndexInfo, splits, nodeQueue)
+
+ // don't enqueue leaf nodes into node queue
+ assert(nodeQueue.isEmpty)
+
+ // set impurity and predict for topNode
+ assert(topNode.stats !== null)
+ assert(topNode.stats.impurity > 0.0)
+
+ // set impurity and predict for child nodes
+ assert(topNode.leftChild.get.toNode.prediction === 0.0)
+ assert(topNode.rightChild.get.toNode.prediction === 1.0)
+ assert(topNode.leftChild.get.stats.impurity === 0.0)
+ assert(topNode.rightChild.get.stats.impurity === 0.0)
+ }
+
+ test("Avoid aggregation if impurity is 0.0") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
+ val input = sc.parallelize(arr)
+
+ val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 5,
+ numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
+ val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+ val splits = RandomForest.findSplits(input, metadata, seed = 42)
+
+ val treeInput = TreePoint.convertToTreeRDD(input, splits, metadata)
+ val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, withReplacement = false)
+
+ val topNode = LearningNode.emptyNode(nodeIndex = 1)
+ assert(topNode.isLeaf === false)
+ assert(topNode.stats === null)
+
+ val nodesForGroup = Map((0, Array(topNode)))
+ val treeToNodeToIndexInfo = Map((0, Map(
+ (topNode.id, new RandomForest.NodeIndexInfo(0, None))
+ )))
+ val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
+ RandomForest.findBestSplits(baggedInput, metadata, Array(topNode),
+ nodesForGroup, treeToNodeToIndexInfo, splits, nodeQueue)
+
+ // don't enqueue a node into node queue if its impurity is 0.0
+ assert(nodeQueue.isEmpty)
+
+ // set impurity and predict for topNode
+ assert(topNode.stats !== null)
+ assert(topNode.stats.impurity > 0.0)
+
+ // set impurity and predict for child nodes
+ assert(topNode.leftChild.get.toNode.prediction === 0.0)
+ assert(topNode.rightChild.get.toNode.prediction === 1.0)
+ assert(topNode.leftChild.get.stats.impurity === 0.0)
+ assert(topNode.rightChild.get.stats.impurity === 0.0)
+ }
+
+ test("Use soft prediction for binary classification with ordered categorical features") {
+ // The following dataset is set up such that the best split is {1} vs. {0, 2}.
+ // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen.
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0)),
+ LabeledPoint(1.0, Vectors.dense(2.0)))
+ val input = sc.parallelize(arr)
+
+ // Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
+ val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1,
+ numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
+
+ val model = RandomForest.run(input, strategy, numTrees = 1, featureSubsetStrategy = "all",
+ seed = 42).head
+ model.rootNode match {
+ case n: InternalNode => n.split match {
+ case s: CategoricalSplit =>
+ assert(s.leftCategories === Array(1.0))
+ }
+ }
+ }
+
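
As a quick check of the comment at the top of the test above (not part of the suite), the per-category label counts explain why ordering categories by the soft prediction P(label = 1) isolates category 1:

    // (label-0 count, label-1 count) per category value, taken from the 12-point dataset above.
    val countsByCategory = List(
      0.0 -> (3, 1),
      1.0 -> (4, 0),
      2.0 -> (3, 1))
    val softOrder = countsByCategory
      .sortBy { case (_, (n0, n1)) => n1.toDouble / (n0 + n1) } // order by P(label = 1)
      .map(_._1)
    println(softOrder) // List(1.0, 0.0, 2.0): category 1 separates cleanly from {0, 2}
    // Every category's hard (majority) prediction is 0, so ordering by hard predictions cannot
    // distinguish them, which is how the weaker {0} vs. {1, 2} split could be picked instead.
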
+ test("Second level node building with vs. without groups") {
+ val arr = OldDTSuite.generateOrderedLabeledPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ // For tree with 1 group
+ val strategy1 =
+ new OldStrategy(OldAlgo.Classification, Entropy, 3, 2, 100, maxMemoryInMB = 1000)
+ // For tree with multiple groups
+ val strategy2 =
+ new OldStrategy(OldAlgo.Classification, Entropy, 3, 2, 100, maxMemoryInMB = 0)
+
+ val tree1 = RandomForest.run(rdd, strategy1, numTrees = 1, featureSubsetStrategy = "all",
+ seed = 42).head
+ val tree2 = RandomForest.run(rdd, strategy2, numTrees = 1, featureSubsetStrategy = "all",
+ seed = 42).head
+
+ def getChildren(rootNode: Node): Array[InternalNode] = rootNode match {
+ case n: InternalNode =>
+ assert(n.leftChild.isInstanceOf[InternalNode])
+ assert(n.rightChild.isInstanceOf[InternalNode])
+ Array(n.leftChild.asInstanceOf[InternalNode], n.rightChild.asInstanceOf[InternalNode])
+ }
+
+ // Get the second-level children from both trees (single-group vs. multi-group construction).
+ val children1 = getChildren(tree1.rootNode)
+ val children2 = getChildren(tree2.rootNode)
+
+ // Verify whether the splits obtained using single group and multiple group level
+ // construction strategies are the same.
+ for (i <- 0 until 2) {
+ assert(children1(i).gain > 0)
+ assert(children2(i).gain > 0)
+ assert(children1(i).split === children2(i).split)
+ assert(children1(i).impurity === children2(i).impurity)
+ assert(children1(i).impurityStats.stats === children2(i).impurityStats.stats)
+ assert(children1(i).leftChild.impurity === children2(i).leftChild.impurity)
+ assert(children1(i).rightChild.impurity === children2(i).rightChild.impurity)
+ assert(children1(i).prediction === children2(i).prediction)
+ }
+ }
+
+ def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: OldStrategy) {
+ val numFeatures = 50
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000)
+ val rdd = sc.parallelize(arr)
+
+ // Select a feature subset for the top nodes and check that the selection matches expectations.
+ def checkFeatureSubsetStrategy(
+ numTrees: Int,
+ featureSubsetStrategy: String,
+ numFeaturesPerNode: Int): Unit = {
+ val seeds = Array(123, 5354, 230, 349867, 23987)
+ val maxMemoryUsage: Long = 128 * 1024L * 1024L
+ val metadata =
+ DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees, featureSubsetStrategy)
+ seeds.foreach { seed =>
+ val failString = s"Failed on test with:" +
+ s"numTrees=$numTrees, featureSubsetStrategy=$featureSubsetStrategy," +
+ s" numFeaturesPerNode=$numFeaturesPerNode, seed=$seed"
+ val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
+ val topNodes: Array[LearningNode] = new Array[LearningNode](numTrees)
+ Range(0, numTrees).foreach { treeIndex =>
+ topNodes(treeIndex) = LearningNode.emptyNode(nodeIndex = 1)
+ nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))
+ }
+ val rng = new scala.util.Random(seed = seed)
+ val (nodesForGroup: Map[Int, Array[LearningNode]],
+ treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) =
+ RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
+
+ assert(nodesForGroup.size === numTrees, failString)
+ assert(nodesForGroup.values.forall(_.length == 1), failString) // 1 node per tree
+
+ if (numFeaturesPerNode == numFeatures) {
+ // featureSubset values should all be None
+ assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)),
+ failString)
+ } else {
+ // Check number of features.
+ assert(treeToNodeToIndexInfo.values.forall(_.values.forall(
+ _.featureSubset.get.length === numFeaturesPerNode)), failString)
+ }
+ }
+ }
+
+ checkFeatureSubsetStrategy(numTrees = 1, "auto", numFeatures)
+ checkFeatureSubsetStrategy(numTrees = 1, "all", numFeatures)
+ checkFeatureSubsetStrategy(numTrees = 1, "sqrt", math.sqrt(numFeatures).ceil.toInt)
+ checkFeatureSubsetStrategy(numTrees = 1, "log2",
+ (math.log(numFeatures) / math.log(2)).ceil.toInt)
+ checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt)
+
+ checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures)
+ checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt)
+ checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt)
+ checkFeatureSubsetStrategy(numTrees = 2, "log2",
+ (math.log(numFeatures) / math.log(2)).ceil.toInt)
+ checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt)
+ }
+
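
For reference, a small sketch (outside the suite) of the per-strategy feature counts that the checkFeatureSubsetStrategy calls above expect, assuming numFeatures = 50 as in this helper:

    // Expected numFeaturesPerNode for each strategy, with numFeatures = 50:
    val numFeatures = 50
    val expected = Map(
      "all"      -> numFeatures,                                       // 50
      "sqrt"     -> math.sqrt(numFeatures).ceil.toInt,                 // 8
      "log2"     -> (math.log(numFeatures) / math.log(2)).ceil.toInt,  // 6
      "onethird" -> (numFeatures / 3.0).ceil.toInt                     // 17
    )
    // "auto" resolves to all features for a single tree and to sqrt once numTrees > 1,
    // matching the numTrees = 1 vs. numTrees = 2 expectations above.
    println(expected)
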
+ test("Binary classification with continuous features: subsampling features") {
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 2,
+ numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+ binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
+ }
+
+ test("Binary classification with continuous features and node Id cache: subsampling features") {
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 2,
+ numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+ useNodeIdCache = true)
+ binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
+ }
+
test("computeFeatureImportance, featureImportances") {
/* Build tree for testing, with this structure:
grandParent
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 89b64fce96..bb1041b109 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -18,430 +18,23 @@
package org.apache.spark.mllib.tree
import scala.collection.JavaConverters._
-import scala.collection.mutable
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
-import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint}
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model._
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
/////////////////////////////////////////////////////////////////////////////
- // Tests examining individual elements of training
- /////////////////////////////////////////////////////////////////////////////
-
- test("Binary classification with continuous features: split and bin calculation") {
- val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
- assert(arr.length === 1000)
- val rdd = sc.parallelize(arr)
- val strategy = new Strategy(Classification, Gini, 3, 2, 100)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
- assert(!metadata.isUnordered(featureIndex = 0))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- assert(splits.length === 2)
- assert(bins.length === 2)
- assert(splits(0).length === 99)
- assert(bins(0).length === 100)
- }
-
- test("Binary classification with binary (ordered) categorical features:" +
- " split and bin calculation") {
- val arr = DecisionTreeSuite.generateCategoricalDataPoints()
- assert(arr.length === 1000)
- val rdd = sc.parallelize(arr)
- val strategy = new Strategy(
- Classification,
- Gini,
- maxDepth = 2,
- numClasses = 2,
- maxBins = 100,
- categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2))
-
- val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- assert(!metadata.isUnordered(featureIndex = 0))
- assert(!metadata.isUnordered(featureIndex = 1))
- assert(splits.length === 2)
- assert(bins.length === 2)
- // no bins or splits pre-computed for ordered categorical features
- assert(splits(0).length === 0)
- assert(bins(0).length === 0)
- }
-
- test("Binary classification with 3-ary (ordered) categorical features," +
- " with no samples for one category") {
- val arr = DecisionTreeSuite.generateCategoricalDataPoints()
- assert(arr.length === 1000)
- val rdd = sc.parallelize(arr)
- val strategy = new Strategy(
- Classification,
- Gini,
- maxDepth = 2,
- numClasses = 2,
- maxBins = 100,
- categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
-
- val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
- assert(!metadata.isUnordered(featureIndex = 0))
- assert(!metadata.isUnordered(featureIndex = 1))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- assert(splits.length === 2)
- assert(bins.length === 2)
- // no bins or splits pre-computed for ordered categorical features
- assert(splits(0).length === 0)
- assert(bins(0).length === 0)
- }
-
- test("extract categories from a number for multiclass classification") {
- val l = DecisionTree.extractMultiClassCategories(13, 10)
- assert(l.length === 3)
- assert(List(3.0, 2.0, 0.0).toSeq === l.toSeq)
- }
-
- test("find splits for a continuous feature") {
- // find splits for normal case
- {
- val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
- Map(), Set(),
- Array(6), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
- )
- val featureSamples = Array.fill(200000)(math.random)
- val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
- assert(splits.length === 5)
- assert(fakeMetadata.numSplits(0) === 5)
- assert(fakeMetadata.numBins(0) === 6)
- // check returned splits are distinct
- assert(splits.distinct.length === splits.length)
- }
-
- // find splits should not return identical splits
- // when there are not enough split candidates, reduce the number of splits in metadata
- {
- val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
- Map(), Set(),
- Array(5), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
- )
- val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
- val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
- assert(splits.length === 3)
- // check returned splits are distinct
- assert(splits.distinct.length === splits.length)
- }
-
- // find splits when most samples close to the minimum
- {
- val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
- Map(), Set(),
- Array(3), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
- )
- val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
- val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
- assert(splits.length === 2)
- assert(splits(0) === 2.0)
- assert(splits(1) === 3.0)
- }
-
- // find splits when most samples close to the maximum
- {
- val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
- Map(), Set(),
- Array(3), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
- )
- val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
- val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
- assert(splits.length === 1)
- assert(splits(0) === 1.0)
- }
- }
-
- test("Multiclass classification with unordered categorical features:" +
- " split and bin calculations") {
- val arr = DecisionTreeSuite.generateCategoricalDataPoints()
- assert(arr.length === 1000)
- val rdd = sc.parallelize(arr)
- val strategy = new Strategy(
- Classification,
- Gini,
- maxDepth = 2,
- numClasses = 100,
- maxBins = 100,
- categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
-
- val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
- assert(metadata.isUnordered(featureIndex = 0))
- assert(metadata.isUnordered(featureIndex = 1))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- assert(splits.length === 2)
- assert(bins.length === 2)
- assert(splits(0).length === 3)
- assert(bins(0).length === 0)
- assert(metadata.numSplits(0) === 3)
- assert(metadata.numBins(0) === 3)
- assert(metadata.numSplits(1) === 3)
- assert(metadata.numBins(1) === 3)
-
- // Expecting 2^2 - 1 = 3 bins/splits
- assert(splits(0)(0).feature === 0)
- assert(splits(0)(0).threshold === Double.MinValue)
- assert(splits(0)(0).featureType === Categorical)
- assert(splits(0)(0).categories.length === 1)
- assert(splits(0)(0).categories.contains(0.0))
- assert(splits(1)(0).feature === 1)
- assert(splits(1)(0).threshold === Double.MinValue)
- assert(splits(1)(0).featureType === Categorical)
- assert(splits(1)(0).categories.length === 1)
- assert(splits(1)(0).categories.contains(0.0))
-
- assert(splits(0)(1).feature === 0)
- assert(splits(0)(1).threshold === Double.MinValue)
- assert(splits(0)(1).featureType === Categorical)
- assert(splits(0)(1).categories.length === 1)
- assert(splits(0)(1).categories.contains(1.0))
- assert(splits(1)(1).feature === 1)
- assert(splits(1)(1).threshold === Double.MinValue)
- assert(splits(1)(1).featureType === Categorical)
- assert(splits(1)(1).categories.length === 1)
- assert(splits(1)(1).categories.contains(1.0))
-
- assert(splits(0)(2).feature === 0)
- assert(splits(0)(2).threshold === Double.MinValue)
- assert(splits(0)(2).featureType === Categorical)
- assert(splits(0)(2).categories.length === 2)
- assert(splits(0)(2).categories.contains(0.0))
- assert(splits(0)(2).categories.contains(1.0))
- assert(splits(1)(2).feature === 1)
- assert(splits(1)(2).threshold === Double.MinValue)
- assert(splits(1)(2).featureType === Categorical)
- assert(splits(1)(2).categories.length === 2)
- assert(splits(1)(2).categories.contains(0.0))
- assert(splits(1)(2).categories.contains(1.0))
-
- }
-
- test("Multiclass classification with ordered categorical features: split and bin calculations") {
- val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
- assert(arr.length === 3000)
- val rdd = sc.parallelize(arr)
- val strategy = new Strategy(
- Classification,
- Gini,
- maxDepth = 2,
- numClasses = 100,
- maxBins = 100,
- categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
- // 2^(10-1) - 1 > 100, so categorical features will be ordered
-
- val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
- assert(!metadata.isUnordered(featureIndex = 0))
- assert(!metadata.isUnordered(featureIndex = 1))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- assert(splits.length === 2)
- assert(bins.length === 2)
- // no bins or splits pre-computed for ordered categorical features
- assert(splits(0).length === 0)
- assert(bins(0).length === 0)
- }
-
- test("Avoid aggregation on the last level") {
- val arr = Array(
- LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)),
- LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
- LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
- LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
- val input = sc.parallelize(arr)
-
- val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
- numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
- val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
- val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
-
- val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
- val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
-
- val topNode = Node.emptyNode(nodeIndex = 1)
- assert(topNode.predict.predict === Double.MinValue)
- assert(topNode.impurity === -1.0)
- assert(topNode.isLeaf === false)
-
- val nodesForGroup = Map((0, Array(topNode)))
- val treeToNodeToIndexInfo = Map((0, Map(
- (topNode.id, new RandomForest.NodeIndexInfo(0, None))
- )))
- val nodeQueue = new mutable.Queue[(Int, Node)]()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
- nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
-
- // don't enqueue leaf nodes into node queue
- assert(nodeQueue.isEmpty)
-
- // set impurity and predict for topNode
- assert(topNode.predict.predict !== Double.MinValue)
- assert(topNode.impurity !== -1.0)
-
- // set impurity and predict for child nodes
- assert(topNode.leftNode.get.predict.predict === 0.0)
- assert(topNode.rightNode.get.predict.predict === 1.0)
- assert(topNode.leftNode.get.impurity === 0.0)
- assert(topNode.rightNode.get.impurity === 0.0)
- }
-
- test("Avoid aggregation if impurity is 0.0") {
- val arr = Array(
- LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)),
- LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
- LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
- LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
- val input = sc.parallelize(arr)
-
- val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
- numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
- val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
- val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
-
- val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
- val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
-
- val topNode = Node.emptyNode(nodeIndex = 1)
- assert(topNode.predict.predict === Double.MinValue)
- assert(topNode.impurity === -1.0)
- assert(topNode.isLeaf === false)
-
- val nodesForGroup = Map((0, Array(topNode)))
- val treeToNodeToIndexInfo = Map((0, Map(
- (topNode.id, new RandomForest.NodeIndexInfo(0, None))
- )))
- val nodeQueue = new mutable.Queue[(Int, Node)]()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
- nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
-
- // don't enqueue a node into node queue if its impurity is 0.0
- assert(nodeQueue.isEmpty)
-
- // set impurity and predict for topNode
- assert(topNode.predict.predict !== Double.MinValue)
- assert(topNode.impurity !== -1.0)
-
- // set impurity and predict for child nodes
- assert(topNode.leftNode.get.predict.predict === 0.0)
- assert(topNode.rightNode.get.predict.predict === 1.0)
- assert(topNode.leftNode.get.impurity === 0.0)
- assert(topNode.rightNode.get.impurity === 0.0)
- }
-
- test("Use soft prediction for binary classification with ordered categorical features") {
- // The following dataset is set up such that the best split is {1} vs. {0, 2}.
- // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen.
- val arr = Array(
- LabeledPoint(0.0, Vectors.dense(0.0)),
- LabeledPoint(0.0, Vectors.dense(0.0)),
- LabeledPoint(0.0, Vectors.dense(0.0)),
- LabeledPoint(1.0, Vectors.dense(0.0)),
- LabeledPoint(0.0, Vectors.dense(1.0)),
- LabeledPoint(0.0, Vectors.dense(1.0)),
- LabeledPoint(0.0, Vectors.dense(1.0)),
- LabeledPoint(0.0, Vectors.dense(1.0)),
- LabeledPoint(0.0, Vectors.dense(2.0)),
- LabeledPoint(0.0, Vectors.dense(2.0)),
- LabeledPoint(0.0, Vectors.dense(2.0)),
- LabeledPoint(1.0, Vectors.dense(2.0)))
- val input = sc.parallelize(arr)
-
- // Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
- val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
- numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
-
- val model = new DecisionTree(strategy).run(input)
- model.topNode.split.get match {
- case Split(_, _, _, categories: List[Double]) =>
- assert(categories === List(1.0))
- }
- }
-
- test("Second level node building with vs. without groups") {
- val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
- assert(arr.length === 1000)
- val rdd = sc.parallelize(arr)
- val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- assert(splits.length === 2)
- assert(splits(0).length === 99)
- assert(bins.length === 2)
- assert(bins(0).length === 100)
-
- // Train a 1-node model
- val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1,
- numClasses = 2, maxBins = 100)
- val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
- val rootNode1 = modelOneNode.topNode.deepCopy()
- val rootNode2 = modelOneNode.topNode.deepCopy()
- assert(rootNode1.leftNode.nonEmpty)
- assert(rootNode1.rightNode.nonEmpty)
-
- val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
-
- // Single group second level tree construction.
- val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get)))
- val treeToNodeToIndexInfo = Map((0, Map(
- (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)),
- (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None)))))
- val nodeQueue = new mutable.Queue[(Int, Node)]()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1),
- nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
- val children1 = new Array[Node](2)
- children1(0) = rootNode1.leftNode.get
- children1(1) = rootNode1.rightNode.get
-
- // Train one second-level node at a time.
- val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get)))
- val treeToNodeToIndexInfoA = Map((0, Map(
- (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
- nodeQueue.clear()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
- nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue)
- val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get)))
- val treeToNodeToIndexInfoB = Map((0, Map(
- (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
- nodeQueue.clear()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
- nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue)
- val children2 = new Array[Node](2)
- children2(0) = rootNode2.leftNode.get
- children2(1) = rootNode2.rightNode.get
-
- // Verify whether the splits obtained using single group and multiple group level
- // construction strategies are the same.
- for (i <- 0 until 2) {
- assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0)
- assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0)
- assert(children1(i).split === children2(i).split)
- assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty)
- val stats1 = children1(i).stats.get
- val stats2 = children2(i).stats.get
- assert(stats1.gain === stats2.gain)
- assert(stats1.impurity === stats2.impurity)
- assert(stats1.leftImpurity === stats2.leftImpurity)
- assert(stats1.rightImpurity === stats2.rightImpurity)
- assert(children1(i).predict.predict === children2(i).predict.predict)
- }
- }
-
- /////////////////////////////////////////////////////////////////////////////
// Tests calling train()
/////////////////////////////////////////////////////////////////////////////
@@ -457,22 +50,11 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
- val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
- assert(!metadata.isUnordered(featureIndex = 0))
- assert(!metadata.isUnordered(featureIndex = 1))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- assert(splits.length === 2)
- assert(bins.length === 2)
- // no bins or splits pre-computed for ordered categorical features
- assert(splits(0).length === 0)
- assert(bins(0).length === 0)
-
val rootNode = DecisionTree.train(rdd, strategy).topNode
val split = rootNode.split.get
assert(split.categories === List(1.0))
assert(split.featureType === Categorical)
- assert(split.threshold === Double.MinValue)
val stats = rootNode.stats.get
assert(stats.gain > 0)
@@ -501,7 +83,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(split.categories.length === 1)
assert(split.categories.contains(1.0))
assert(split.featureType === Categorical)
- assert(split.threshold === Double.MinValue)
val stats = rootNode.stats.get
assert(stats.gain > 0)
@@ -539,18 +120,11 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- assert(splits.length === 2)
- assert(splits(0).length === 99)
- assert(bins.length === 2)
- assert(bins(0).length === 100)
-
val rootNode = DecisionTree.train(rdd, strategy).topNode
- val stats = rootNode.stats.get
- assert(stats.gain === 0)
- assert(stats.leftImpurity === 0)
- assert(stats.rightImpurity === 0)
+ assert(rootNode.impurity === 0)
+ assert(rootNode.stats.isEmpty)
+ assert(rootNode.predict.predict === 0)
}
test("Binary classification stump with fixed label 1 for Gini") {
@@ -563,18 +137,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- assert(splits.length === 2)
- assert(splits(0).length === 99)
- assert(bins.length === 2)
- assert(bins(0).length === 100)
-
val rootNode = DecisionTree.train(rdd, strategy).topNode
- val stats = rootNode.stats.get
- assert(stats.gain === 0)
- assert(stats.leftImpurity === 0)
- assert(stats.rightImpurity === 0)
+ assert(rootNode.impurity === 0)
+ assert(rootNode.stats.isEmpty)
assert(rootNode.predict.predict === 1)
}
@@ -588,18 +154,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- assert(splits.length === 2)
- assert(splits(0).length === 99)
- assert(bins.length === 2)
- assert(bins(0).length === 100)
-
val rootNode = DecisionTree.train(rdd, strategy).topNode
- val stats = rootNode.stats.get
- assert(stats.gain === 0)
- assert(stats.leftImpurity === 0)
- assert(stats.rightImpurity === 0)
+ assert(rootNode.impurity === 0)
+ assert(rootNode.stats.isEmpty)
assert(rootNode.predict.predict === 0)
}
@@ -613,18 +171,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- assert(splits.length === 2)
- assert(splits(0).length === 99)
- assert(bins.length === 2)
- assert(bins(0).length === 100)
-
val rootNode = DecisionTree.train(rdd, strategy).topNode
- val stats = rootNode.stats.get
- assert(stats.gain === 0)
- assert(stats.leftImpurity === 0)
- assert(stats.rightImpurity === 0)
+ assert(rootNode.impurity === 0)
+ assert(rootNode.stats.isEmpty)
assert(rootNode.predict.predict === 1)
}
@@ -718,7 +268,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClasses = 3, maxBins = 100)
assert(strategy.isMulticlassClassification)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
val model = DecisionTree.train(rdd, strategy)
DecisionTreeSuite.validateClassifier(model, arr, 0.9)
@@ -807,8 +356,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
// test when no valid split can be found
val rootNode = model.topNode
- val gain = rootNode.stats.get
- assert(gain == InformationGainStats.invalidInformationGainStats)
+ assert(rootNode.stats.isEmpty)
}
test("do not choose split that does not satisfy min instance per node requirements") {
@@ -828,9 +376,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val rootNode = DecisionTree.train(rdd, strategy).topNode
val split = rootNode.split.get
- val gain = rootNode.stats.get
+ val gainStats = rootNode.stats.get
assert(split.feature == 1)
- assert(gain != InformationGainStats.invalidInformationGainStats)
+ assert(gainStats.gain >= 0)
+ assert(gainStats.impurity >= 0)
}
test("split must satisfy min info gain requirements") {
@@ -852,10 +401,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
}
// test when no valid split can be found
- val rootNode = model.topNode
-
- val gain = rootNode.stats.get
- assert(gain == InformationGainStats.invalidInformationGainStats)
+ assert(model.topNode.stats.isEmpty)
}
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index c72fc9bb4f..bec61ba6a0 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -17,16 +17,13 @@
package org.apache.spark.mllib.tree
-import scala.collection.mutable
-
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
import org.apache.spark.mllib.tree.impurity.{Gini, Variance}
-import org.apache.spark.mllib.tree.model.{Node, RandomForestModel}
+import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.util.Utils
@@ -42,7 +39,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = numTrees,
featureSubsetStrategy = "auto", seed = 123)
- assert(rf.trees.size === 1)
+ assert(rf.trees.length === 1)
val rfTree = rf.trees(0)
val dt = DecisionTree.train(rdd, strategy)
@@ -78,7 +75,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees,
featureSubsetStrategy = "auto", seed = 123)
- assert(rf.trees.size === 1)
+ assert(rf.trees.length === 1)
val rfTree = rf.trees(0)
val dt = DecisionTree.train(rdd, strategy)
@@ -108,80 +105,6 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
regressionTestWithContinuousFeatures(strategy)
}
- def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: Strategy) {
- val numFeatures = 50
- val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000)
- val rdd = sc.parallelize(arr)
-
- // Select feature subset for top nodes. Return true if OK.
- def checkFeatureSubsetStrategy(
- numTrees: Int,
- featureSubsetStrategy: String,
- numFeaturesPerNode: Int): Unit = {
- val seeds = Array(123, 5354, 230, 349867, 23987)
- val maxMemoryUsage: Long = 128 * 1024L * 1024L
- val metadata =
- DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees, featureSubsetStrategy)
- seeds.foreach { seed =>
- val failString = s"Failed on test with:" +
- s"numTrees=$numTrees, featureSubsetStrategy=$featureSubsetStrategy," +
- s" numFeaturesPerNode=$numFeaturesPerNode, seed=$seed"
- val nodeQueue = new mutable.Queue[(Int, Node)]()
- val topNodes: Array[Node] = new Array[Node](numTrees)
- Range(0, numTrees).foreach { treeIndex =>
- topNodes(treeIndex) = Node.emptyNode(nodeIndex = 1)
- nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))
- }
- val rng = new scala.util.Random(seed = seed)
- val (nodesForGroup: Map[Int, Array[Node]],
- treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) =
- RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
-
- assert(nodesForGroup.size === numTrees, failString)
- assert(nodesForGroup.values.forall(_.size == 1), failString) // 1 node per tree
-
- if (numFeaturesPerNode == numFeatures) {
- // featureSubset values should all be None
- assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)),
- failString)
- } else {
- // Check number of features.
- assert(treeToNodeToIndexInfo.values.forall(_.values.forall(
- _.featureSubset.get.size === numFeaturesPerNode)), failString)
- }
- }
- }
-
- checkFeatureSubsetStrategy(numTrees = 1, "auto", numFeatures)
- checkFeatureSubsetStrategy(numTrees = 1, "all", numFeatures)
- checkFeatureSubsetStrategy(numTrees = 1, "sqrt", math.sqrt(numFeatures).ceil.toInt)
- checkFeatureSubsetStrategy(numTrees = 1, "log2",
- (math.log(numFeatures) / math.log(2)).ceil.toInt)
- checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt)
-
- checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures)
- checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt)
- checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt)
- checkFeatureSubsetStrategy(numTrees = 2, "log2",
- (math.log(numFeatures) / math.log(2)).ceil.toInt)
- checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt)
- }
-
- test("Binary classification with continuous features: subsampling features") {
- val categoricalFeaturesInfo = Map.empty[Int, Int]
- val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
- numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
- binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
- }
-
- test("Binary classification with continuous features and node Id cache: subsampling features") {
- val categoricalFeaturesInfo = Map.empty[Int, Int]
- val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
- numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
- useNodeIdCache = true)
- binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
- }
-
test("alternating categorical and continuous features with multiclass labels to test indexing") {
val arr = new Array[LabeledPoint](4)
arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0))
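Sketch (not part of this patch): the removed checkFeatureSubsetStrategy calls above assert a fixed mapping from featureSubsetStrategy to the number of features examined per node. For the binary classification test with 50 features deleted here, those asserted counts can be reproduced as follows; the function name is hypothetical.
```
// Hypothetical helper reproducing the per-node feature counts asserted in the
// removed test: "auto" behaves like "all" when numTrees == 1 and like "sqrt"
// when numTrees > 1 in the classification case exercised above.
def expectedFeaturesPerNode(featureSubsetStrategy: String, numTrees: Int, numFeatures: Int): Int =
  featureSubsetStrategy match {
    case "all" => numFeatures
    case "sqrt" => math.sqrt(numFeatures).ceil.toInt
    case "log2" => (math.log(numFeatures) / math.log(2)).ceil.toInt
    case "onethird" => (numFeatures / 3.0).ceil.toInt
    case "auto" => if (numTrees == 1) numFeatures else math.sqrt(numFeatures).ceil.toInt
    case other => throw new IllegalArgumentException(s"unknown featureSubsetStrategy: $other")
  }

// For example, expectedFeaturesPerNode("sqrt", numTrees = 2, numFeatures = 50) == 8.
```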
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index 7dd2937db7..715fa9e9f8 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -164,7 +164,8 @@ if __name__ == "__main__":
"split will be discarded as invalid. Should be >= 1.", "TypeConverters.toInt"),
("minInfoGain", "Minimum information gain for a split to be considered at a tree node.",
"TypeConverters.toFloat"),
- ("maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.",
+ ("maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation. If too small," +
+ " then 1 node will be split per iteration, and its aggregates may exceed this size.",
"TypeConverters.toInt"),
("cacheNodeIds", "If false, the algorithm will pass trees to executors to match " +
"instances with nodes. If true, the algorithm will cache node IDs for each instance. " +
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index 83fbd59039..d79d55e463 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -568,7 +568,7 @@ class DecisionTreeParams(Params):
maxBins = Param(Params._dummy(), "maxBins", "Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.", typeConverter=TypeConverters.toInt)
minInstancesPerNode = Param(Params._dummy(), "minInstancesPerNode", "Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.", typeConverter=TypeConverters.toInt)
minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.", typeConverter=TypeConverters.toFloat)
- maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.", typeConverter=TypeConverters.toInt)
+ maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size.", typeConverter=TypeConverters.toInt)
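Sketch (not part of this patch) of the behavior described by the updated maxMemoryInMB documentation: a limit far below what a single node's histogram aggregation needs no longer aborts training; one node is still split per iteration, and that node's aggregates may exceed the limit. The DataFrame name is an assumption.
```
import org.apache.spark.ml.classification.DecisionTreeClassifier

// Illustrative only: a deliberately tiny memory budget for histogram aggregation.
val dt = new DecisionTreeClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setMaxMemoryInMB(1) // too small for one node's aggregates; training still proceeds
// val model = dt.fit(trainingDF)  // trainingDF: assumed DataFrame with "label" and "features" columns
```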
cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.", typeConverter=TypeConverters.toBoolean)