 mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala              | 191
 mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala    |   3
 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala    |  11
 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala |   3
 mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala   |   2
 mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala                |  37
 mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala         | 277
 7 files changed, 268 insertions(+), 256 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 98596569b8..56bb881210 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -87,17 +87,11 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
val maxDepth = strategy.maxDepth
require(maxDepth <= 30,
s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")
- // Number of nodes to allocate: max number of nodes possible given the depth of the tree, plus 1
- val maxNumNodesPlus1 = Node.startIndexInLevel(maxDepth + 1)
- // Initialize an array to hold parent impurity calculations for each node.
- val parentImpurities = new Array[Double](maxNumNodesPlus1)
- // dummy value for top node (updated during first split calculation)
- val nodes = new Array[Node](maxNumNodesPlus1)
// Calculate level for single group construction
// Max memory usage for aggregates
- val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024
+ val maxMemoryUsage = strategy.maxMemoryInMB * 1024L * 1024L
logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
// TODO: Calculate memory usage more precisely.
val numElementsPerNode = DecisionTree.getElementsPerNode(metadata)
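
Worth noting why the literals became 1024L: with maxMemoryInMB as an Int, the whole product is computed in 32-bit arithmetic and silently wraps for settings of 2048 MB and above. A standalone sketch (hypothetical value, not part of the patch):

    val maxMemoryInMB = 4096                    // hypothetical configuration
    val wrapped = maxMemoryInMB * 1024 * 1024   // Int arithmetic: 2^32 wraps to 0
    val correct = maxMemoryInMB * 1024L * 1024L // Long arithmetic: 4294967296
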
@@ -120,81 +114,35 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
* beforehand and is not used in later levels.
*/
+ var topNode: Node = null // set on first iteration
var level = 0
var break = false
while (level <= maxDepth && !break) {
-
logDebug("#####################################")
logDebug("level = " + level)
logDebug("#####################################")
// Find best split for all nodes at a level.
timer.start("findBestSplits")
- val splitsStatsForLevel: Array[(Split, InformationGainStats, Predict)] =
- DecisionTree.findBestSplits(treeInput, parentImpurities,
- metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
+ val (tmpTopNode: Node, doneTraining: Boolean) = DecisionTree.findBestSplits(treeInput,
+ metadata, level, topNode, splits, bins, maxLevelForSingleGroup, timer)
timer.stop("findBestSplits")
- val levelNodeIndexOffset = Node.startIndexInLevel(level)
- for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
- val nodeIndex = levelNodeIndexOffset + index
-
- // Extract info for this node (index) at the current level.
- timer.start("extractNodeInfo")
- val split = nodeSplitStats._1
- val stats = nodeSplitStats._2
- val predict = nodeSplitStats._3.predict
- val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
- val node = new Node(nodeIndex, predict, isLeaf, Some(split), None, None, Some(stats))
- logDebug("Node = " + node)
- nodes(nodeIndex) = node
- timer.stop("extractNodeInfo")
-
- if (level != 0) {
- // Set parent.
- val parentNodeIndex = Node.parentIndex(nodeIndex)
- if (Node.isLeftChild(nodeIndex)) {
- nodes(parentNodeIndex).leftNode = Some(nodes(nodeIndex))
- } else {
- nodes(parentNodeIndex).rightNode = Some(nodes(nodeIndex))
- }
- }
- // Extract info for nodes at the next lower level.
- timer.start("extractInfoForLowerLevels")
- if (level < maxDepth) {
- val leftChildIndex = Node.leftChildIndex(nodeIndex)
- val leftImpurity = stats.leftImpurity
- logDebug("leftChildIndex = " + leftChildIndex + ", impurity = " + leftImpurity)
- parentImpurities(leftChildIndex) = leftImpurity
-
- val rightChildIndex = Node.rightChildIndex(nodeIndex)
- val rightImpurity = stats.rightImpurity
- logDebug("rightChildIndex = " + rightChildIndex + ", impurity = " + rightImpurity)
- parentImpurities(rightChildIndex) = rightImpurity
- }
- timer.stop("extractInfoForLowerLevels")
- logDebug("final best split = " + split)
+ if (level == 0) {
+ topNode = tmpTopNode
}
- require(Node.maxNodesInLevel(level) == splitsStatsForLevel.length)
- // Check whether all the nodes at the current level at leaves.
- val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
- logDebug("all leaf = " + allLeaf)
- if (allLeaf) {
- break = true // no more tree construction
- } else {
- level += 1
+ if (doneTraining) {
+ break = true
+ logDebug("done training")
}
+
+ level += 1
}
logDebug("#####################################")
logDebug("Extracting tree model")
logDebug("#####################################")
- // Initialize the top or root node of the tree.
- val topNode = nodes(1)
- // Build the full tree using the node info calculated in the level-wise best split calculations.
- topNode.build(nodes)
-
timer.stop("total")
logInfo("Internal timing for DecisionTree:")
@@ -409,24 +357,26 @@ object DecisionTree extends Serializable with Logging {
* multiple groups if the level-wise training task could lead to memory overflow.
*
* @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
- * @param parentImpurities Impurities for all parent nodes for the current level
* @param metadata Learning and dataset metadata
* @param level Level of the tree
+ * @param topNode Root node of the tree (or invalid node when training the first level).
* @param splits possible splits for all features, indexed (numFeatures)(numSplits)
* @param bins possible bins for all features, indexed (numFeatures)(numBins)
* @param maxLevelForSingleGroup the deepest level for single-group level-wise computation.
- * @return array (over nodes) of splits with best split for each node at a given level.
+ * @return (root, doneTraining) where:
+ * root = Root node (which is newly created on the first iteration),
+ * doneTraining = true if no more internal nodes were created.
*/
private[tree] def findBestSplits(
input: RDD[TreePoint],
- parentImpurities: Array[Double],
metadata: DecisionTreeMetadata,
level: Int,
- nodes: Array[Node],
+ topNode: Node,
splits: Array[Array[Split]],
bins: Array[Array[Bin]],
maxLevelForSingleGroup: Int,
- timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats, Predict)] = {
+ timer: TimeTracker = new TimeTracker): (Node, Boolean) = {
+
// split into groups to avoid memory overflow during aggregation
if (level > maxLevelForSingleGroup) {
// When information for all nodes at a given level cannot be stored in memory,
@@ -435,18 +385,18 @@ object DecisionTree extends Serializable with Logging {
// numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
val numGroups = 1 << level - maxLevelForSingleGroup
logDebug("numGroups = " + numGroups)
- var bestSplits = new Array[(Split, InformationGainStats, Predict)](0)
// Iterate over each group of nodes at a level.
var groupIndex = 0
+ var doneTraining = true
while (groupIndex < numGroups) {
- val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, metadata, level,
- nodes, splits, bins, timer, numGroups, groupIndex)
- bestSplits = Array.concat(bestSplits, bestSplitsForGroup)
+ val (tmpRoot, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level,
+ topNode, splits, bins, timer, numGroups, groupIndex)
+ doneTraining = doneTraining && doneTrainingGroup
groupIndex += 1
}
- bestSplits
+ (topNode, doneTraining) // Not first iteration, so topNode was already set.
} else {
- findBestSplitsPerGroup(input, parentImpurities, metadata, level, nodes, splits, bins, timer)
+ findBestSplitsPerGroup(input, metadata, level, topNode, splits, bins, timer)
}
}
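
For reference, the grouping arithmetic in the hunk above: at or below maxLevelForSingleGroup a level is trained as a single group; above it, the group count doubles per level. A small sketch (assuming maxLevelForSingleGroup = 10, as in the code comment):

    val maxLevelForSingleGroup = 10
    for (level <- 10 to 13) {
      val numGroups =
        if (level > maxLevelForSingleGroup) 1 << (level - maxLevelForSingleGroup) else 1
      println(s"level $level -> $numGroups group(s)")   // prints 1, 2, 4, 8
    }
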
@@ -586,27 +536,27 @@ object DecisionTree extends Serializable with Logging {
* Returns an array of optimal splits for a group of nodes at a given level
*
* @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
- * @param parentImpurities Impurities for all parent nodes for the current level
* @param metadata Learning and dataset metadata
* @param level Level of the tree
- * @param nodes Array of all nodes in the tree. Used for matching data points to nodes.
+ * @param topNode Root node of the tree (or invalid node when training the first level).
* @param splits possible splits for all features, indexed (numFeatures)(numSplits)
* @param bins possible bins for all features, indexed (numFeatures)(numBins)
* @param numGroups total number of node groups at the current level. Default value is set to 1.
* @param groupIndex index of the node group being processed. Default value is set to 0.
- * @return array of splits with best splits for all nodes at a given level.
+ * @return (root, doneTraining) where:
+ * root = Root node (which is newly created on the first iteration),
+ * doneTraining = true if no more internal nodes were created.
*/
private def findBestSplitsPerGroup(
input: RDD[TreePoint],
- parentImpurities: Array[Double],
metadata: DecisionTreeMetadata,
level: Int,
- nodes: Array[Node],
+ topNode: Node,
splits: Array[Array[Split]],
bins: Array[Array[Bin]],
timer: TimeTracker,
numGroups: Int = 1,
- groupIndex: Int = 0): Array[(Split, InformationGainStats, Predict)] = {
+ groupIndex: Int = 0): (Node, Boolean) = {
/*
* The high-level descriptions of the best split optimizations are noted here.
@@ -663,7 +613,7 @@ object DecisionTree extends Serializable with Logging {
0
} else {
val globalNodeIndex =
- predictNodeIndex(nodes(1), treePoint.binnedFeatures, bins, metadata.unorderedFeatures)
+ predictNodeIndex(topNode, treePoint.binnedFeatures, bins, metadata.unorderedFeatures)
globalNodeIndex - globalNodeIndexOffset
}
}
@@ -706,33 +656,63 @@ object DecisionTree extends Serializable with Logging {
// Calculate best splits for all nodes at a given level
timer.start("chooseSplits")
- val bestSplits = new Array[(Split, InformationGainStats, Predict)](numNodes)
- // Iterating over all nodes at this level
+ // On the first iteration, we need to get and return the newly created root node.
+ var newTopNode: Node = topNode
+
+ // Iterate over all nodes at this level
var nodeIndex = 0
+ var internalNodeCount = 0
while (nodeIndex < numNodes) {
- val nodeImpurity = parentImpurities(globalNodeIndexOffset + nodeIndex)
- logDebug("node impurity = " + nodeImpurity)
- bestSplits(nodeIndex) =
- binsToBestSplit(binAggregates, nodeIndex, nodeImpurity, level, metadata, splits)
- logDebug("best split = " + bestSplits(nodeIndex)._1)
+ val (split: Split, stats: InformationGainStats, predict: Predict) =
+ binsToBestSplit(binAggregates, nodeIndex, level, metadata, splits)
+ logDebug("best split = " + split)
+
+ val globalNodeIndex = globalNodeIndexOffset + nodeIndex
+
+ // Extract info for this node at the current level.
+ val isLeaf = (stats.gain <= 0) || (level == metadata.maxDepth)
+ val node =
+ new Node(globalNodeIndex, predict.predict, isLeaf, Some(split), None, None, Some(stats))
+ logDebug("Node = " + node)
+
+ if (!isLeaf) {
+ internalNodeCount += 1
+ }
+ if (level == 0) {
+ newTopNode = node
+ } else {
+ // Set parent.
+ val parentNode = Node.getNode(Node.parentIndex(globalNodeIndex), topNode)
+ if (Node.isLeftChild(globalNodeIndex)) {
+ parentNode.leftNode = Some(node)
+ } else {
+ parentNode.rightNode = Some(node)
+ }
+ }
+ if (level < metadata.maxDepth) {
+ logDebug("leftChildIndex = " + Node.leftChildIndex(globalNodeIndex) +
+ ", impurity = " + stats.leftImpurity)
+ logDebug("rightChildIndex = " + Node.rightChildIndex(globalNodeIndex) +
+ ", impurity = " + stats.rightImpurity)
+ }
+
nodeIndex += 1
}
timer.stop("chooseSplits")
- bestSplits
+ val doneTraining = internalNodeCount == 0
+ (newTopNode, doneTraining)
}
/**
* Calculate the information gain for a given (feature, split) based upon left/right aggregates.
* @param leftImpurityCalculator left node aggregates for this (feature, split)
* @param rightImpurityCalculator right node aggregate for this (feature, split)
- * @param topImpurity impurity of the parent node
* @return information gain and statistics for all splits
*/
private def calculateGainForSplit(
leftImpurityCalculator: ImpurityCalculator,
rightImpurityCalculator: ImpurityCalculator,
- topImpurity: Double,
level: Int,
metadata: DecisionTreeMetadata): InformationGainStats = {
val leftCount = leftImpurityCalculator.count
@@ -747,14 +727,10 @@ object DecisionTree extends Serializable with Logging {
val totalCount = leftCount + rightCount
- // impurity of parent node
- val impurity = if (level > 0) {
- topImpurity
- } else {
- val parentNodeAgg = leftImpurityCalculator.copy
- parentNodeAgg.add(rightImpurityCalculator)
- parentNodeAgg.calculate()
- }
+ val parentNodeAgg = leftImpurityCalculator.copy
+ parentNodeAgg.add(rightImpurityCalculator)
+
+ val impurity = parentNodeAgg.calculate()
val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
val rightImpurity = rightImpurityCalculator.calculate()
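
The hunk above removes the level-0 special case: parent impurity is now always recomputed by merging the left and right child aggregates, which is what lets parentImpurities disappear from the API. A minimal sketch of that merge, using a hypothetical Gini calculator rather than MLlib's ImpurityCalculator:

    // Hypothetical stand-in for ImpurityCalculator; only copy/add/calculate are sketched.
    class GiniCalculator(val counts: Array[Double]) {
      def copy: GiniCalculator = new GiniCalculator(counts.clone())
      def add(other: GiniCalculator): Unit = {
        var i = 0
        while (i < counts.length) { counts(i) += other.counts(i); i += 1 }
      }
      def calculate(): Double = {
        val total = counts.sum
        if (total == 0) 0.0 else 1.0 - counts.map(c => (c / total) * (c / total)).sum
      }
    }

    val left  = new GiniCalculator(Array(30.0, 10.0))  // class counts in left child
    val right = new GiniCalculator(Array(5.0, 55.0))   // class counts in right child
    val parentNodeAgg = left.copy                      // same pattern as the patch
    parentNodeAgg.add(right)
    val impurity = parentNodeAgg.calculate()           // impurity of the merged (parent) node
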
@@ -795,19 +771,15 @@ object DecisionTree extends Serializable with Logging {
* Find the best split for a node.
* @param binAggregates Bin statistics.
* @param nodeIndex Index for node to split in this (level, group).
- * @param nodeImpurity Impurity of the node (nodeIndex).
* @return tuple for best split: (Split, information gain)
*/
private def binsToBestSplit(
binAggregates: DTStatsAggregator,
nodeIndex: Int,
- nodeImpurity: Double,
level: Int,
metadata: DecisionTreeMetadata,
splits: Array[Array[Split]]): (Split, InformationGainStats, Predict) = {
- logDebug("node impurity = " + nodeImpurity)
-
// calculate predict only once
var predict: Option[Predict] = None
@@ -831,8 +803,7 @@ object DecisionTree extends Serializable with Logging {
val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
rightChildStats.subtract(leftChildStats)
predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
- val gainStats =
- calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
+ val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata)
(splitIdx, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
@@ -845,8 +816,7 @@ object DecisionTree extends Serializable with Logging {
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
- val gainStats =
- calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
+ val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
@@ -917,8 +887,7 @@ object DecisionTree extends Serializable with Logging {
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
rightChildStats.subtract(leftChildStats)
predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
- val gainStats =
- calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
+ val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
val categoriesForSplit =
@@ -937,8 +906,8 @@ object DecisionTree extends Serializable with Logging {
/**
* Get the number of values to be stored per node in the bin aggregates.
*/
- private def getElementsPerNode(metadata: DecisionTreeMetadata): Int = {
- val totalBins = metadata.numBins.sum
+ private def getElementsPerNode(metadata: DecisionTreeMetadata): Long = {
+ val totalBins = metadata.numBins.map(_.toLong).sum
if (metadata.isClassification) {
metadata.numClasses * totalBins
} else {
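
Same overflow theme as the maxMemoryUsage fix above: the bin total is now summed in Long, since numClasses * totalBins can exceed Int.MaxValue on wide datasets. A sketch with made-up sizes:

    val numBins: Array[Int] = Array.fill(100000)(32)   // hypothetical: 100k features, 32 bins each
    val numClasses = 1000                              // hypothetical
    val totalBins: Long = numBins.map(_.toLong).sum    // 3,200,000 -- the sum itself still fits...
    val elementsPerNode: Long = numClasses * totalBins // ...but the product (3.2e9) needs Long
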
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 987fe632c9..31d1e8ac30 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -75,6 +75,9 @@ class Strategy (
if (algo == Classification) {
require(numClassesForClassification >= 2)
}
+ require(minInstancesPerNode >= 1,
+ s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
+
val isMulticlassClassification =
algo == Classification && numClassesForClassification > 2
val isMulticlassWithCategoricalFeatures
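
A quick usage sketch of the new guard (hedged: parameter names follow the Strategy constructor as used in the updated tests below):

    import org.apache.spark.mllib.tree.configuration.Algo.Classification
    import org.apache.spark.mllib.tree.configuration.Strategy
    import org.apache.spark.mllib.tree.impurity.Gini

    // Now fails fast at construction instead of misbehaving during training.
    val bad = new Strategy(Classification, Gini, maxDepth = 4,
      numClassesForClassification = 2, minInstancesPerNode = 0) // throws IllegalArgumentException
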
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
index 866d85a79b..61a9424671 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
@@ -65,14 +65,7 @@ private[tree] class DTStatsAggregator(
* Offset for each feature for calculating indices into the [[allStats]] array.
*/
private val featureOffsets: Array[Int] = {
- def featureOffsetsCalc(total: Int, featureIndex: Int): Int = {
- if (isUnordered(featureIndex)) {
- total + 2 * numBins(featureIndex)
- } else {
- total + numBins(featureIndex)
- }
- }
- Range(0, numFeatures).scanLeft(0)(featureOffsetsCalc).map(statsSize * _).toArray
+ numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
}
/**
@@ -149,7 +142,7 @@ private[tree] class DTStatsAggregator(
s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," +
s" but was called for ordered feature $featureIndex.")
val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex)
- (baseOffset, baseOffset + numBins(featureIndex) * statsSize)
+ (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
}
/**
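
The (numBins(featureIndex) >> 1) above pairs with the simplified featureOffsets: numBins for an unordered feature now counts both the left-child and right-child bin sets, so the right-child offset is half the feature's bin count. A worked sketch with hypothetical numbers:

    // One unordered feature with numBins = 8 (4 left-child + 4 right-child bins), statsSize = 3.
    val statsSize = 3
    val numBinsForFeature = 8
    val baseOffset = 0                    // this feature's offset within the node's stats
    val leftChildOffset  = baseOffset     // left-child bins start here
    val rightChildOffset = baseOffset + (numBinsForFeature >> 1) * statsSize  // = 12
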
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index 5ceaa8154d..b6d49e5555 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -46,6 +46,7 @@ private[tree] class DecisionTreeMetadata(
val numBins: Array[Int],
val impurity: Impurity,
val quantileStrategy: QuantileStrategy,
+ val maxDepth: Int,
val minInstancesPerNode: Int,
val minInfoGain: Double) extends Serializable {
@@ -129,7 +130,7 @@ private[tree] object DecisionTreeMetadata {
new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
- strategy.impurity, strategy.quantileCalculationStrategy,
+ strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
strategy.minInstancesPerNode, strategy.minInfoGain)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index 0594fd0749..271b2c4ad8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -46,7 +46,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
* Predict values for the given data set using the model trained.
*
* @param features RDD representing data points to be predicted
- * @return RDD[Int] where each entry contains the corresponding prediction
+ * @return RDD of predictions for each of the given data points
*/
def predict(features: RDD[Vector]): RDD[Double] = {
features.map(x => predict(x))
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 5b8a4cbed2..5f0095d23c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -55,6 +55,8 @@ class Node (
* build the left node and right nodes if not leaf
* @param nodes array of nodes
*/
+ @deprecated("build should no longer be used since trees are constructed on-the-fly in training",
+ "1.2.0")
def build(nodes: Array[Node]): Unit = {
logDebug("building node " + id + " at level " + Node.indexToLevel(id))
logDebug("id = " + id + ", split = " + split)
@@ -94,6 +96,23 @@ class Node (
}
/**
+ * Returns a deep copy of the subtree rooted at this node.
+ */
+ private[tree] def deepCopy(): Node = {
+ val leftNodeCopy = if (leftNode.isEmpty) {
+ None
+ } else {
+ Some(leftNode.get.deepCopy())
+ }
+ val rightNodeCopy = if (rightNode.isEmpty) {
+ None
+ } else {
+ Some(rightNode.get.deepCopy())
+ }
+ new Node(id, predict, isLeaf, split, leftNodeCopy, rightNodeCopy, stats)
+ }
+
+ /**
* Get the number of nodes in tree below this node, including leaf nodes.
* E.g., if this is a leaf, returns 0. If both children are leaves, returns 2.
*/
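
deepCopy exists so the updated "Second level node building with vs. without groups" test can train twice from identical starting trees: findBestSplits now attaches children to the passed-in root in place, so sharing one root would let the first pass contaminate the second. A small sketch of the independence it buys (constructor argument order as in the training hunk above: id, predict, isLeaf, split, leftNode, rightNode, stats):

    val root = new Node(1, 0.5, true, None, None, None, None)
    val copy = root.deepCopy()
    copy.leftNode = Some(new Node(2, 0.0, true, None, None, None, None))
    assert(root.leftNode.isEmpty)   // mutating the copy leaves the original untouched
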
@@ -190,4 +209,22 @@ private[tree] object Node {
*/
def startIndexInLevel(level: Int): Int = 1 << level
+ /**
+ * Traces down from a root node to get the node with the given node index.
+ * This assumes the node exists.
+ */
+ def getNode(nodeIndex: Int, rootNode: Node): Node = {
+ var tmpNode: Node = rootNode
+ var levelsToGo = indexToLevel(nodeIndex)
+ while (levelsToGo > 0) {
+ if ((nodeIndex & (1 << levelsToGo - 1)) == 0) {
+ tmpNode = tmpNode.leftNode.get
+ } else {
+ tmpNode = tmpNode.rightNode.get
+ }
+ levelsToGo -= 1
+ }
+ tmpNode
+ }
+
}
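
The bit test in getNode decodes the heap-style node numbering (root = 1; children of i are 2*i and 2*i + 1): starting from the highest level bit, a 0 bit means go left and a 1 bit means go right. A worked sketch for nodeIndex = 5 (binary 101, level 2):

    val nodeIndex = 5                     // binary 101; indexToLevel(5) == 2
    val turns = (1 to 0 by -1).map { bit =>
      if ((nodeIndex & (1 << bit)) == 0) "left" else "right"
    }
    // turns == Vector("left", "right"): root(1) -> left child(2) -> its right child(5)
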
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index fd8547c166..1bd7ea05c4 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -270,19 +270,17 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins(0).length === 0)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
- new Array[Node](0), splits, bins, 10)
+ val (rootNode: Node, doneTraining: Boolean) =
+ DecisionTree.findBestSplits(treeInput, metadata, 0, null, splits, bins, 10)
- val split = bestSplits(0)._1
+ val split = rootNode.split.get
assert(split.categories === List(1.0))
assert(split.featureType === Categorical)
assert(split.threshold === Double.MinValue)
- val stats = bestSplits(0)._2
- val predict = bestSplits(0)._3
+ val stats = rootNode.stats.get
assert(stats.gain > 0)
- assert(predict.predict === 1)
- assert(predict.prob === 0.6)
+ assert(rootNode.predict === 1)
assert(stats.impurity > 0.2)
}
@@ -303,19 +301,18 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
- new Array[Node](0), splits, bins, 10)
+ val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+ null, splits, bins, 10)
- val split = bestSplits(0)._1
+ val split = rootNode.split.get
assert(split.categories.length === 1)
assert(split.categories.contains(1.0))
assert(split.featureType === Categorical)
assert(split.threshold === Double.MinValue)
- val stats = bestSplits(0)._2
- val predict = bestSplits(0)._3.predict
+ val stats = rootNode.stats.get
assert(stats.gain > 0)
- assert(predict === 0.6)
+ assert(rootNode.predict === 0.6)
assert(stats.impurity > 0.2)
}
@@ -356,13 +353,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins(0).length === 100)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
- new Array[Node](0), splits, bins, 10)
- assert(bestSplits.length === 1)
- assert(bestSplits(0)._1.feature === 0)
- assert(bestSplits(0)._2.gain === 0)
- assert(bestSplits(0)._2.leftImpurity === 0)
- assert(bestSplits(0)._2.rightImpurity === 0)
+ val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+ null, splits, bins, 10)
+
+ val split = rootNode.split.get
+ assert(split.feature === 0)
+
+ val stats = rootNode.stats.get
+ assert(stats.gain === 0)
+ assert(stats.leftImpurity === 0)
+ assert(stats.rightImpurity === 0)
}
test("Binary classification stump with fixed label 1 for Gini") {
@@ -382,14 +382,17 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins(0).length === 100)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
- new Array[Node](0), splits, bins, 10)
- assert(bestSplits.length === 1)
- assert(bestSplits(0)._1.feature === 0)
- assert(bestSplits(0)._2.gain === 0)
- assert(bestSplits(0)._2.leftImpurity === 0)
- assert(bestSplits(0)._2.rightImpurity === 0)
- assert(bestSplits(0)._3.predict === 1)
+ val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+ null, splits, bins, 10)
+
+ val split = rootNode.split.get
+ assert(split.feature === 0)
+
+ val stats = rootNode.stats.get
+ assert(stats.gain === 0)
+ assert(stats.leftImpurity === 0)
+ assert(stats.rightImpurity === 0)
+ assert(rootNode.predict === 1)
}
test("Binary classification stump with fixed label 0 for Entropy") {
@@ -409,14 +412,17 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins(0).length === 100)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
- new Array[Node](0), splits, bins, 10)
- assert(bestSplits.length === 1)
- assert(bestSplits(0)._1.feature === 0)
- assert(bestSplits(0)._2.gain === 0)
- assert(bestSplits(0)._2.leftImpurity === 0)
- assert(bestSplits(0)._2.rightImpurity === 0)
- assert(bestSplits(0)._3.predict === 0)
+ val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+ null, splits, bins, 10)
+
+ val split = rootNode.split.get
+ assert(split.feature === 0)
+
+ val stats = rootNode.stats.get
+ assert(stats.gain === 0)
+ assert(stats.leftImpurity === 0)
+ assert(stats.rightImpurity === 0)
+ assert(rootNode.predict === 0)
}
test("Binary classification stump with fixed label 1 for Entropy") {
@@ -436,14 +442,17 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins(0).length === 100)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
- new Array[Node](0), splits, bins, 10)
- assert(bestSplits.length === 1)
- assert(bestSplits(0)._1.feature === 0)
- assert(bestSplits(0)._2.gain === 0)
- assert(bestSplits(0)._2.leftImpurity === 0)
- assert(bestSplits(0)._2.rightImpurity === 0)
- assert(bestSplits(0)._3.predict === 1)
+ val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+ null, splits, bins, 10)
+
+ val split = rootNode.split.get
+ assert(split.feature === 0)
+
+ val stats = rootNode.stats.get
+ assert(stats.gain === 0)
+ assert(stats.leftImpurity === 0)
+ assert(stats.rightImpurity === 0)
+ assert(rootNode.predict === 1)
}
test("Second level node building with vs. without groups") {
@@ -459,40 +468,46 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins(0).length === 100)
// Train a 1-node model
- val strategyOneNode = new Strategy(Classification, Entropy, 1, 2, 100)
+ val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1,
+ numClassesForClassification = 2, maxBins = 100)
val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
- val nodes: Array[Node] = new Array[Node](8)
- nodes(1) = modelOneNode.topNode
- nodes(1).leftNode = None
- nodes(1).rightNode = None
-
- val parentImpurities = Array(0, 0.5, 0.5, 0.5)
+ val rootNodeCopy1 = modelOneNode.topNode.deepCopy()
+ val rootNodeCopy2 = modelOneNode.topNode.deepCopy()
// Single group second level tree construction.
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1, nodes,
- splits, bins, 10)
- assert(bestSplits.length === 2)
- assert(bestSplits(0)._2.gain > 0)
- assert(bestSplits(1)._2.gain > 0)
+ val (rootNode, _) = DecisionTree.findBestSplits(treeInput, metadata, 1,
+ rootNodeCopy1, splits, bins, 10)
+ assert(rootNode.leftNode.nonEmpty)
+ assert(rootNode.rightNode.nonEmpty)
+ val children1 = new Array[Node](2)
+ children1(0) = rootNode.leftNode.get
+ children1(1) = rootNode.rightNode.get
// maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second
// level tree construction.
- val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1,
- nodes, splits, bins, 0)
- assert(bestSplitsWithGroups.length === 2)
- assert(bestSplitsWithGroups(0)._2.gain > 0)
- assert(bestSplitsWithGroups(1)._2.gain > 0)
+ val (rootNode2, _) = DecisionTree.findBestSplits(treeInput, metadata, 1,
+ rootNodeCopy2, splits, bins, 0)
+ assert(rootNode2.leftNode.nonEmpty)
+ assert(rootNode2.rightNode.nonEmpty)
+ val children2 = new Array[Node](2)
+ children2(0) = rootNode2.leftNode.get
+ children2(1) = rootNode2.rightNode.get
// Verify whether the splits obtained using single group and multiple group level
// construction strategies are the same.
- for (i <- 0 until bestSplits.length) {
- assert(bestSplits(i)._1 === bestSplitsWithGroups(i)._1)
- assert(bestSplits(i)._2.gain === bestSplitsWithGroups(i)._2.gain)
- assert(bestSplits(i)._2.impurity === bestSplitsWithGroups(i)._2.impurity)
- assert(bestSplits(i)._2.leftImpurity === bestSplitsWithGroups(i)._2.leftImpurity)
- assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity)
- assert(bestSplits(i)._3.predict === bestSplitsWithGroups(i)._3.predict)
+ for (i <- 0 until 2) {
+ assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0)
+ assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0)
+ assert(children1(i).split === children2(i).split)
+ assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty)
+ val stats1 = children1(i).stats.get
+ val stats2 = children2(i).stats.get
+ assert(stats1.gain === stats2.gain)
+ assert(stats1.impurity === stats2.impurity)
+ assert(stats1.leftImpurity === stats2.leftImpurity)
+ assert(stats1.rightImpurity === stats2.rightImpurity)
+ assert(children1(i).predict === children2(i).predict)
}
}
@@ -508,15 +523,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
- new Array[Node](0), splits, bins, 10)
+ val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+ null, splits, bins, 10)
- assert(bestSplits.length === 1)
- val bestSplit = bestSplits(0)._1
- assert(bestSplit.feature === 0)
- assert(bestSplit.categories.length === 1)
- assert(bestSplit.categories.contains(1))
- assert(bestSplit.featureType === Categorical)
+ val split = rootNode.split.get
+ assert(split.feature === 0)
+ assert(split.categories.length === 1)
+ assert(split.categories.contains(1))
+ assert(split.featureType === Categorical)
}
test("Binary classification stump with 1 continuous feature, to check off-by-1 error") {
@@ -573,16 +587,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
- new Array[Node](0), splits, bins, 10)
-
- assert(bestSplits.length === 1)
- val bestSplit = bestSplits(0)._1
- assert(bestSplit.feature === 0)
- assert(bestSplit.categories.length === 1)
- assert(bestSplit.categories.contains(1))
- assert(bestSplit.featureType === Categorical)
- val gain = bestSplits(0)._2
+ val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+ null, splits, bins, 10)
+
+ val split = rootNode.split.get
+ assert(split.feature === 0)
+ assert(split.categories.length === 1)
+ assert(split.categories.contains(1))
+ assert(split.featureType === Categorical)
+
+ val gain = rootNode.stats.get
assert(gain.leftImpurity === 0)
assert(gain.rightImpurity === 0)
}
@@ -600,16 +614,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
- new Array[Node](0), splits, bins, 10)
-
- assert(bestSplits.length === 1)
- val bestSplit = bestSplits(0)._1
+ val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+ null, splits, bins, 10)
- assert(bestSplit.feature === 1)
- assert(bestSplit.featureType === Continuous)
- assert(bestSplit.threshold > 1980)
- assert(bestSplit.threshold < 2020)
+ val split = rootNode.split.get
+ assert(split.feature === 1)
+ assert(split.featureType === Continuous)
+ assert(split.threshold > 1980)
+ assert(split.threshold < 2020)
}
@@ -627,16 +639,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
- new Array[Node](0), splits, bins, 10)
+ val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+ null, splits, bins, 10)
- assert(bestSplits.length === 1)
- val bestSplit = bestSplits(0)._1
-
- assert(bestSplit.feature === 1)
- assert(bestSplit.featureType === Continuous)
- assert(bestSplit.threshold > 1980)
- assert(bestSplit.threshold < 2020)
+ val split = rootNode.split.get
+ assert(split.feature === 1)
+ assert(split.featureType === Continuous)
+ assert(split.threshold > 1980)
+ assert(split.threshold < 2020)
}
test("Multiclass classification stump with 10-ary (ordered) categorical features") {
@@ -652,15 +662,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
- new Array[Node](0), splits, bins, 10)
+ val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+ null, splits, bins, 10)
- assert(bestSplits.length === 1)
- val bestSplit = bestSplits(0)._1
- assert(bestSplit.feature === 0)
- assert(bestSplit.categories.length === 1)
- assert(bestSplit.categories.contains(1.0))
- assert(bestSplit.featureType === Categorical)
+ val split = rootNode.split.get
+ assert(split.feature === 0)
+ assert(split.categories.length === 1)
+ assert(split.categories.contains(1.0))
+ assert(split.featureType === Categorical)
}
test("Multiclass classification tree with 10-ary (ordered) categorical features," +
@@ -698,12 +707,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
- new Array[Node](0), splits, bins, 10)
+ val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+ null, splits, bins, 10)
- assert(bestSplits.length == 1)
- val bestInfoStats = bestSplits(0)._2
- assert(bestInfoStats == InformationGainStats.invalidInformationGainStats)
+ val gain = rootNode.stats.get
+ assert(gain == InformationGainStats.invalidInformationGainStats)
}
test("don't choose split that doesn't satisfy min instance per node requirements") {
@@ -722,14 +730,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
- new Array[Node](0), splits, bins, 10)
+ val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+ null, splits, bins, 10)
- assert(bestSplits.length == 1)
- val bestSplit = bestSplits(0)._1
- val bestSplitStats = bestSplits(0)._1
- assert(bestSplit.feature == 1)
- assert(bestSplitStats != InformationGainStats.invalidInformationGainStats)
+ val split = rootNode.split.get
+ val gain = rootNode.stats.get
+ assert(split.feature == 1)
+ assert(gain != InformationGainStats.invalidInformationGainStats)
}
test("split must satisfy min info gain requirements") {
@@ -754,12 +761,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
- new Array[Node](0), splits, bins, 10)
+ val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+ null, splits, bins, 10)
- assert(bestSplits.length == 1)
- val bestInfoStats = bestSplits(0)._2
- assert(bestInfoStats == InformationGainStats.invalidInformationGainStats)
+ val gain = rootNode.stats.get
+ assert(gain == InformationGainStats.invalidInformationGainStats)
}
}
@@ -786,13 +792,16 @@ object DecisionTreeSuite {
def generateOrderedLabeledPoints(): Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](1000)
for (i <- 0 until 1000) {
- if (i < 600) {
- val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
- arr(i) = lp
+ val label = if (i < 100) {
+ 0.0
+ } else if (i < 500) {
+ 1.0
+ } else if (i < 900) {
+ 0.0
} else {
- val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 1000.0 - i))
- arr(i) = lp
+ 1.0
}
+ arr(i) = new LabeledPoint(label, Vectors.dense(i.toDouble, 1000.0 - i))
}
arr
}