path: root/mllib/src
author     Joseph K. Bradley <joseph.kurata.bradley@gmail.com>   2014-09-08 09:47:13 -0700
committer  Xiangrui Meng <meng@databricks.com>                   2014-09-08 09:47:13 -0700
commit     711356b422c66e2a80377a9f43fce97282460520 (patch)
tree       203459a199620778ec407845d77a5040767ea2f5 /mllib/src
parent     0d1cc4ae42e1f73538dd8b9b1880ca9e5b124108 (diff)
download   spark-711356b422c66e2a80377a9f43fce97282460520.tar.gz
           spark-711356b422c66e2a80377a9f43fce97282460520.tar.bz2
           spark-711356b422c66e2a80377a9f43fce97282460520.zip
[SPARK-3086] [SPARK-3043] [SPARK-3156] [mllib] DecisionTree aggregation improvements
Summary:
1. Variable numBins for each feature [SPARK-3043]
2. Reduced data reshaping in aggregation [SPARK-3043]
3. Choose ordering for ordered categorical features adaptively [SPARK-3156]
4. Changed nodes to use 1-indexing [SPARK-3086]
5. Small clean-ups

Note: This PR looks bigger than it is since I moved several functions from inside findBestSplitsPerGroup to outside of it (to make it clear what was being serialized in the aggregation).

Speedups: This update helps most when many features use few bins but a few features use many bins. Some example results on speedups with 2M examples, 3.5K features (15-worker EC2 cluster):
* Example where old code was reasonably efficient (1/2 continuous, 1/4 binary, 1/4 20-category): 164.813 --> 116.491 sec
* Example where old code wasted many bins (1/10 continuous, 81/100 binary, 9/100 20-category): 128.701 --> 39.334 sec

Details:

(1) Variable numBins for each feature [SPARK-3043]
DecisionTreeMetadata now computes a variable numBins for each feature. It also tracks numSplits.

(2) Reduced data reshaping in aggregation [SPARK-3043]
Added DTStatsAggregator, a wrapper around the aggregate statistics array for easy but efficient indexing.
* Added ImpurityAggregator and ImpurityCalculator classes, to make DecisionTree code more oblivious to the type of impurity.
* Design note: I originally tried creating Impurity classes which stored data and storing the aggregates in an Array[Array[Array[Impurity]]]. However, this led to significant slowdowns, perhaps because of overhead in creating so many objects.
The aggregate statistics are never reshaped, and cumulative sums are computed in-place.
Updated the layout of aggregation functions. The update simplifies things by (1) dividing features into ordered/unordered (instead of ordered/unordered/continuous) and (2) making use of the DTStatsAggregator for indexing. For this update, the following functions were refactored:
* updateBinForOrderedFeature
* updateBinForUnorderedFeature
* binaryOrNotCategoricalBinSeqOp
* multiclassWithCategoricalBinSeqOp
* regressionBinSeqOp
The above 5 functions were replaced with:
* orderedBinSeqOp
* someUnorderedBinSeqOp
Other changes:
* calculateGainForSplit now treats all feature types the same way.
* Eliminated extractLeftRightNodeAggregates.

(3) Choose ordering for ordered categorical features adaptively [SPARK-3156]
Updated binsToBestSplit():
* This now computes cumulative sums of stats for ordered features.
* For ordered categorical features, it chooses an ordering for categories. (This used to be done by findSplitsBins.)
* Uses iterators to shorten code and avoid building an Array[Array[InformationGainStats]].
Side effects:
* In findSplitsBins: A sample of the data is only taken for data with continuous features. It is not needed for data with only categorical features.
* In findSplitsBins: splits and bins are no longer pre-computed for ordered categorical features since they are not needed.
* TreePoint binning is simpler for categorical features.

(4) Changed nodes to use 1-indexing [SPARK-3086]
Nodes used to be indexed from 0. Now they are indexed from 1. Node indexing functions are now collected in object Node (Node.scala).

(5) Small clean-ups
Eliminated functions extractNodeInfo() and extractInfoForLowerLevels() to reduce duplicate code. Eliminated InvalidBinIndex since it is no longer used.
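[Editor's note] The 1-indexing in item (4) is ordinary 1-based binary-heap arithmetic. The following is only an illustrative sketch of the index helpers the diff calls on object Node (leftChildIndex, rightChildIndex, parentIndex, isLeftChild, startIndexInLevel, maxNodesInLevel); the bodies are assumptions consistent with how the diff uses them, not the Node.scala code added by this commit.

object NodeIndexingSketch {
  // Root is node 1; the children of node i are 2*i (left) and 2*i + 1 (right).
  def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1
  def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1
  def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1
  // Under 1-indexing, a left child always has an even index.
  def isLeftChild(nodeIndex: Int): Boolean = nodeIndex % 2 == 0
  // Level 0 is the root; level `level` holds 2^level nodes and starts at global index 2^level.
  def maxNodesInLevel(level: Int): Int = 1 << level
  def startIndexInLevel(level: Int): Int = 1 << level

  def main(args: Array[String]): Unit = {
    val maxDepth = 3
    // Slot 0 stays unused, so arrays sized startIndexInLevel(maxDepth + 1) hold every node index.
    println(startIndexInLevel(maxDepth + 1))                         // 16
    println((leftChildIndex(1), rightChildIndex(1), parentIndex(3))) // (2,3,1)
    println(isLeftChild(2))                                          // true
  }
}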
CC: mengxr manishamde  Please let me know if you have thoughts on this--thanks!

Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>

Closes #2125 from jkbradley/dt-opt3alt and squashes the following commits:

42c192a [Joseph K. Bradley] Merge branch 'rfs' into dt-opt3alt
d3cc46b [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt3alt
00e4404 [Joseph K. Bradley] optimization for TreePoint construction (pre-computing featureArity and isUnordered as arrays)
425716c [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into rfs
a2acea5 [Joseph K. Bradley] Small optimizations based on profiling
aa4e4df [Joseph K. Bradley] Updated DTStatsAggregator with bug fix (nodeString should not be multiplied by statsSize)
4651154 [Joseph K. Bradley] Changed numBins semantics for unordered features. * Before: numBins = numSplits = (1 << k - 1) - 1 * Now: numBins = 2 * numSplits = 2 * [(1 << k - 1) - 1] * This also involved changing the semantics of: ** DecisionTreeMetadata.numUnorderedBins()
1e3b1c7 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt3alt
1485fcc [Joseph K. Bradley] Made some DecisionTree methods private.
92f934f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt3alt
e676da1 [Joseph K. Bradley] Updated documentation for DecisionTree
37ca845 [Joseph K. Bradley] Fixed problem with how DecisionTree handles ordered categorical features.
105f8ab [Joseph K. Bradley] Removed commented-out getEmptyBinAggregates from DecisionTree
062c31d [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt3alt
6d32ccd [Joseph K. Bradley] In DecisionTree.binsToBestSplit, changed loops to iterators to shorten code.
807cd00 [Joseph K. Bradley] Finished DTStatsAggregator, a wrapper around the aggregate statistics for easy but hopefully efficient indexing. Modified old ImpurityAggregator classes and renamed them ImpurityCalculator; added ImpurityAggregator classes which work with DTStatsAggregator but do not store data. Unit tests all succeed.
f2166fd [Joseph K. Bradley] still working on DTStatsAggregator
92f7118 [Joseph K. Bradley] Added partly written DTStatsAggregator
fd8df30 [Joseph K. Bradley] Moved some aggregation helpers outside of findBestSplitsPerGroup
d7c53ee [Joseph K. Bradley] Added more doc for ImpurityAggregator
a40f8f1 [Joseph K. Bradley] Changed nodes to be indexed from 1. Tests work.
95cad7c [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt3
5f94342 [Joseph K. Bradley] Added treeAggregate since not yet merged from master. Moved node indexing functions to Node.
61c4509 [Joseph K. Bradley] Fixed bugs from merge: missing DT timer call, and numBins setting. Cleaned up DT Suite some.
3ba7166 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt3
b314659 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt3
9c83363 [Joseph K. Bradley] partial merge but not done yet
45f7ea7 [Joseph K. Bradley] partial merge, not yet done
5fce635 [Joseph K. Bradley] Merge branch 'dt-opt2' into dt-opt3
26d10dd [Joseph K. Bradley] Removed tree/model/Filter.scala since no longer used. Removed debugging println calls in DecisionTree.scala.
356daba [Joseph K. Bradley] Merge branch 'dt-opt1' into dt-opt2
430d782 [Joseph K. Bradley] Added more debug info on binning error. Added some docs.
d036089 [Joseph K. Bradley] Print timing info to logDebug.
e66f1b1 [Joseph K. Bradley] TreePoint * Updated doc * Made some methods private
8464a6e [Joseph K. Bradley] Moved TimeTracker to tree/impl/ in its own file, and cleaned it up. Removed debugging println calls from DecisionTree. Made TreePoint extend Serializable
a87e08f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1
dd4d3aa [Joseph K. Bradley] Mid-process in bug fix: bug for binary classification with categorical features * Bug: Categorical features were all treated as ordered for binary classification. This is possible but would require the bin ordering to be determined on-the-fly after the aggregation. Currently, the ordering is determined a priori and fixed for all splits. * (Temp) Fix: Treat low-arity categorical features as unordered for binary classification. * Related change: I removed most tests for isMulticlass in the code. I instead test metadata for whether there are unordered features. * Status: The bug may be fixed, but more testing needs to be done.
438a660 [Joseph K. Bradley] removed subsampling for mnist8m from DT
86e217f [Joseph K. Bradley] added cache to DT input
e3c84cc [Joseph K. Bradley] Added stuff for mnist8m to DT Runner
51ef781 [Joseph K. Bradley] Fixed bug introduced by last commit: Variance impurity calculation was incorrect since counts were swapped accidentally
fd65372 [Joseph K. Bradley] Major changes: * Created ImpurityAggregator classes, rather than old aggregates. * Feature split/bin semantics are based on ordered vs. unordered ** E.g.: numSplits = numBins for all unordered features, and numSplits = numBins - 1 for all ordered features. * numBins can differ for each feature
c1565a5 [Joseph K. Bradley] Small DecisionTree updates: * Simplification: Updated calculateGainForSplit to take aggregates for a single (feature, split) pair. * Internal doc: findAggForOrderedFeatureClassification
b914f3b [Joseph K. Bradley] DecisionTree optimization: eliminated filters + small changes
b2ed1f3 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt
0f676e2 [Joseph K. Bradley] Optimizations + Bug fix for DecisionTree
3211f02 [Joseph K. Bradley] Optimizing DecisionTree * Added TreePoint representation to avoid calling findBin multiple times. * (not working yet, but debugging)
f61e9d2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
bcf874a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
511ec85 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
a95bc22 [Joseph K. Bradley] timing for DecisionTree internals
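[Editor's note] As a rough illustration of item (3): for an ordered categorical feature, binsToBestSplit sorts the categories by a per-category centroid (the label centroid for regression and binary classification, the impurity for multiclass) and then evaluates only the K - 1 splits given by prefixes of that ordering. The sketch below shows just that ordering step, using plain label means instead of the patch's ImpurityCalculator; it is a simplification for illustration, not code from this commit.

object CategoryOrderingSketch {
  def main(args: Array[String]): Unit = {
    // (categoryValue, label) pairs for a single ordered categorical feature.
    val points: Seq[(Int, Double)] = Seq((0, 1.0), (0, 1.0), (1, 0.0), (1, 1.0), (2, 0.0))

    // Centroid of each category: here simply the mean label (regression / binary classification case).
    val centroidForCategories: Map[Int, Double] =
      points.groupBy(_._1).map { case (cat, ls) => cat -> ls.map(_._2).sum / ls.size }

    // Categories sorted by centroid; this ordering determines which splits are considered.
    val categoriesSortedByCentroid: List[Int] =
      centroidForCategories.toList.sortBy(_._2).map(_._1)

    // With K categories there are K - 1 candidate splits: each prefix of the sorted order
    // goes to the left child and the remaining categories go to the right child.
    val candidateSplits: Seq[List[Int]] =
      (1 until categoriesSortedByCentroid.size).map(i => categoriesSortedByCentroid.take(i))

    println(categoriesSortedByCentroid) // List(2, 1, 0)
    println(candidateSplits)            // Vector(List(2), List(2, 1))
  }
}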
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala                1341
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala       213
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala     73
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala                93
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala              84
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala                 84
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala            127
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala             72
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala                      7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala                    85
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala            391
11 files changed, 1322 insertions, 1248 deletions
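[Editor's note] The DTStatsAggregator added in impl/DTStatsAggregator.scala is, at heart, one flat Double array holding a fixed-size set of sufficient statistics per (node, feature, bin), so that binSeqOp can update it in place and binCombOp can merge two copies element-wise. The sketch below is only a guess at that layout for illustration, assuming per-feature bin counts and a fixed statsSize; it is not the class introduced by this commit, and its method names are hypothetical.

// Illustrative flat layout indexed by (node, feature, bin, statistic).
class FlatStatsAggregatorSketch(numBinsPerFeature: Array[Int], statsSize: Int, numNodes: Int) {
  // Offset of each feature's block within one node's slice of the flat array.
  private val featureOffsets: Array[Int] =
    numBinsPerFeature.scanLeft(0)((total, numBins) => total + numBins * statsSize)
  private val nodeStride: Int = featureOffsets.last
  val stats = new Array[Double](numNodes * nodeStride)

  def nodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride

  // Bump one statistic in the (node, feature, bin) slot; the real class derives which
  // statistics to update from the data point's label via an ImpurityAggregator.
  def update(nodeOffset: Int, featureIndex: Int, binIndex: Int, statIndex: Int, value: Double): Unit =
    stats(nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize + statIndex) += value

  // Element-wise merge of two aggregators, as used by the combine step of treeAggregate.
  def merge(other: FlatStatsAggregatorSketch): FlatStatsAggregatorSketch = {
    var i = 0
    while (i < stats.length) { stats(i) += other.stats(i); i += 1 }
    this
  }
}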
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 5cdd258f6c..dd766c12d2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -28,8 +28,9 @@ import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
-import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TimeTracker, TreePoint}
+import org.apache.spark.mllib.tree.impl._
import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity}
+import org.apache.spark.mllib.tree.impurity._
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@@ -65,36 +66,41 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
val retaggedInput = input.retag(classOf[LabeledPoint])
val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, strategy)
logDebug("algo = " + strategy.algo)
+ logDebug("maxBins = " + metadata.maxBins)
// Find the splits and the corresponding bins (interval between the splits) using a sample
// of the input data.
timer.start("findSplitsBins")
val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata)
- val numBins = bins(0).length
timer.stop("findSplitsBins")
- logDebug("numBins = " + numBins)
+ logDebug("numBins: feature: number of bins")
+ logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
+ s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
+ }.mkString("\n"))
// Bin feature values (TreePoint representation).
// Cache input RDD for speedup during multiple passes.
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
.persist(StorageLevel.MEMORY_AND_DISK)
- val numFeatures = metadata.numFeatures
// depth of the decision tree
val maxDepth = strategy.maxDepth
- // the max number of nodes possible given the depth of the tree
- val maxNumNodes = (2 << maxDepth) - 1
+ require(maxDepth <= 30,
+ s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")
+ // Number of nodes to allocate: max number of nodes possible given the depth of the tree, plus 1
+ val maxNumNodesPlus1 = Node.startIndexInLevel(maxDepth + 1)
// Initialize an array to hold parent impurity calculations for each node.
- val parentImpurities = new Array[Double](maxNumNodes)
+ val parentImpurities = new Array[Double](maxNumNodesPlus1)
// dummy value for top node (updated during first split calculation)
- val nodes = new Array[Node](maxNumNodes)
+ val nodes = new Array[Node](maxNumNodesPlus1)
// Calculate level for single group construction
// Max memory usage for aggregates
val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024
logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
- val numElementsPerNode = DecisionTree.getElementsPerNode(metadata, numBins)
+ // TODO: Calculate memory usage more precisely.
+ val numElementsPerNode = DecisionTree.getElementsPerNode(metadata)
logDebug("numElementsPerNode = " + numElementsPerNode)
val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array
@@ -124,26 +130,29 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
// Find best split for all nodes at a level.
timer.start("findBestSplits")
- val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities,
- metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
+ val splitsStatsForLevel: Array[(Split, InformationGainStats)] =
+ DecisionTree.findBestSplits(treeInput, parentImpurities,
+ metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
timer.stop("findBestSplits")
- val levelNodeIndexOffset = (1 << level) - 1
+ val levelNodeIndexOffset = Node.startIndexInLevel(level)
for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
val nodeIndex = levelNodeIndexOffset + index
- val isLeftChild = level != 0 && nodeIndex % 2 == 1
- val parentNodeIndex = if (isLeftChild) { // -1 for root node
- (nodeIndex - 1) / 2
- } else {
- (nodeIndex - 2) / 2
- }
+
// Extract info for this node (index) at the current level.
timer.start("extractNodeInfo")
- extractNodeInfo(nodeSplitStats, level, index, nodes)
+ val split = nodeSplitStats._1
+ val stats = nodeSplitStats._2
+ val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
+ val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
+ logDebug("Node = " + node)
+ nodes(nodeIndex) = node
timer.stop("extractNodeInfo")
+
if (level != 0) {
// Set parent.
- if (isLeftChild) {
+ val parentNodeIndex = Node.parentIndex(nodeIndex)
+ if (Node.isLeftChild(nodeIndex)) {
nodes(parentNodeIndex).leftNode = Some(nodes(nodeIndex))
} else {
nodes(parentNodeIndex).rightNode = Some(nodes(nodeIndex))
@@ -151,11 +160,21 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
}
// Extract info for nodes at the next lower level.
timer.start("extractInfoForLowerLevels")
- extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities)
+ if (level < maxDepth) {
+ val leftChildIndex = Node.leftChildIndex(nodeIndex)
+ val leftImpurity = stats.leftImpurity
+ logDebug("leftChildIndex = " + leftChildIndex + ", impurity = " + leftImpurity)
+ parentImpurities(leftChildIndex) = leftImpurity
+
+ val rightChildIndex = Node.rightChildIndex(nodeIndex)
+ val rightImpurity = stats.rightImpurity
+ logDebug("rightChildIndex = " + rightChildIndex + ", impurity = " + rightImpurity)
+ parentImpurities(rightChildIndex) = rightImpurity
+ }
timer.stop("extractInfoForLowerLevels")
- logDebug("final best split = " + nodeSplitStats._1)
+ logDebug("final best split = " + split)
}
- require((1 << level) == splitsStatsForLevel.length)
+ require(Node.maxNodesInLevel(level) == splitsStatsForLevel.length)
// Check whether all the nodes at the current level at leaves.
val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
logDebug("all leaf = " + allLeaf)
@@ -171,7 +190,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
logDebug("#####################################")
// Initialize the top or root node of the tree.
- val topNode = nodes(0)
+ val topNode = nodes(1)
// Build the full tree using the node info calculated in the level-wise best split calculations.
topNode.build(nodes)
@@ -183,47 +202,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
new DecisionTreeModel(topNode, strategy.algo)
}
- /**
- * Extract the decision tree node information for the given tree level and node index
- */
- private def extractNodeInfo(
- nodeSplitStats: (Split, InformationGainStats),
- level: Int,
- index: Int,
- nodes: Array[Node]): Unit = {
- val split = nodeSplitStats._1
- val stats = nodeSplitStats._2
- val nodeIndex = (1 << level) - 1 + index
- val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
- val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
- logDebug("Node = " + node)
- nodes(nodeIndex) = node
- }
-
- /**
- * Extract the decision tree node information for the children of the node
- */
- private def extractInfoForLowerLevels(
- level: Int,
- index: Int,
- maxDepth: Int,
- nodeSplitStats: (Split, InformationGainStats),
- parentImpurities: Array[Double]): Unit = {
-
- if (level >= maxDepth) {
- return
- }
-
- val leftNodeIndex = (2 << level) - 1 + 2 * index
- val leftImpurity = nodeSplitStats._2.leftImpurity
- logDebug("leftNodeIndex = " + leftNodeIndex + ", impurity = " + leftImpurity)
- parentImpurities(leftNodeIndex) = leftImpurity
-
- val rightNodeIndex = leftNodeIndex + 1
- val rightImpurity = nodeSplitStats._2.rightImpurity
- logDebug("rightNodeIndex = " + rightNodeIndex + ", impurity = " + rightImpurity)
- parentImpurities(rightNodeIndex) = rightImpurity
- }
}
object DecisionTree extends Serializable with Logging {
@@ -425,9 +403,6 @@ object DecisionTree extends Serializable with Logging {
impurity, maxDepth, maxBins)
}
-
- private val InvalidBinIndex = -1
-
/**
* Returns an array of optimal splits for all nodes at a given level. Splits the task into
* multiple groups if the level-wise training task could lead to memory overflow.
@@ -436,12 +411,12 @@ object DecisionTree extends Serializable with Logging {
* @param parentImpurities Impurities for all parent nodes for the current level
* @param metadata Learning and dataset metadata
* @param level Level of the tree
- * @param splits possible splits for all features
- * @param bins possible bins for all features
+ * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
+ * @param bins possible bins for all features, indexed (numFeatures)(numBins)
* @param maxLevelForSingleGroup the deepest level for single-group level-wise computation.
* @return array (over nodes) of splits with best split for each node at a given level.
*/
- protected[tree] def findBestSplits(
+ private[tree] def findBestSplits(
input: RDD[TreePoint],
parentImpurities: Array[Double],
metadata: DecisionTreeMetadata,
@@ -475,14 +450,147 @@ object DecisionTree extends Serializable with Logging {
}
/**
+ * Get the node index corresponding to this data point.
+ * This function mimics prediction, passing an example from the root node down to a node
+ * at the current level being trained; that node's index is returned.
+ *
+ * @param node Node in tree from which to classify the given data point.
+ * @param binnedFeatures Binned feature vector for data point.
+ * @param bins possible bins for all features, indexed (numFeatures)(numBins)
+ * @param unorderedFeatures Set of indices of unordered features.
+ * @return Leaf index if the data point reaches a leaf.
+ * Otherwise, last node reachable in tree matching this example.
+ * Note: This is the global node index, i.e., the index used in the tree.
+ * This index is different from the index used during training a particular
+ * set of nodes in a (level, group).
+ */
+ private def predictNodeIndex(
+ node: Node,
+ binnedFeatures: Array[Int],
+ bins: Array[Array[Bin]],
+ unorderedFeatures: Set[Int]): Int = {
+ if (node.isLeaf) {
+ node.id
+ } else {
+ val featureIndex = node.split.get.feature
+ val splitLeft = node.split.get.featureType match {
+ case Continuous => {
+ val binIndex = binnedFeatures(featureIndex)
+ val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
+ // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold]
+ // We do not need to check lowSplit since bins are separated by splits.
+ featureValueUpperBound <= node.split.get.threshold
+ }
+ case Categorical => {
+ val featureValue = binnedFeatures(featureIndex)
+ node.split.get.categories.contains(featureValue)
+ }
+ case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.")
+ }
+ if (node.leftNode.isEmpty || node.rightNode.isEmpty) {
+ // Return index from next layer of nodes to train
+ if (splitLeft) {
+ Node.leftChildIndex(node.id)
+ } else {
+ Node.rightChildIndex(node.id)
+ }
+ } else {
+ if (splitLeft) {
+ predictNodeIndex(node.leftNode.get, binnedFeatures, bins, unorderedFeatures)
+ } else {
+ predictNodeIndex(node.rightNode.get, binnedFeatures, bins, unorderedFeatures)
+ }
+ }
+ }
+ }
+
+ /**
+ * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features.
+ *
+ * For ordered features, a single bin is updated.
+ * For unordered features, bins correspond to subsets of categories; either the left or right bin
+ * for each subset is updated.
+ *
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
+ * each (node, feature, bin).
+ * @param treePoint Data point being aggregated.
+ * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
+ * @param bins possible bins for all features, indexed (numFeatures)(numBins)
+ * @param unorderedFeatures Set of indices of unordered features.
+ */
+ private def mixedBinSeqOp(
+ agg: DTStatsAggregator,
+ treePoint: TreePoint,
+ nodeIndex: Int,
+ bins: Array[Array[Bin]],
+ unorderedFeatures: Set[Int]): Unit = {
+ // Iterate over all features.
+ val numFeatures = treePoint.binnedFeatures.size
+ val nodeOffset = agg.getNodeOffset(nodeIndex)
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ if (unorderedFeatures.contains(featureIndex)) {
+ // Unordered feature
+ val featureValue = treePoint.binnedFeatures(featureIndex)
+ val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
+ agg.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex)
+ // Update the left or right bin for each split.
+ val numSplits = agg.numSplits(featureIndex)
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ if (bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) {
+ agg.nodeFeatureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label)
+ } else {
+ agg.nodeFeatureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label)
+ }
+ splitIndex += 1
+ }
+ } else {
+ // Ordered feature
+ val binIndex = treePoint.binnedFeatures(featureIndex)
+ agg.nodeUpdate(nodeOffset, featureIndex, binIndex, treePoint.label)
+ }
+ featureIndex += 1
+ }
+ }
+
+ /**
+ * Helper for binSeqOp, for regression and for classification with only ordered features.
+ *
+ * For each feature, the sufficient statistics of one bin are updated.
+ *
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
+ * each (node, feature, bin).
+ * @param treePoint Data point being aggregated.
+ * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
+ * @return agg
+ */
+ private def orderedBinSeqOp(
+ agg: DTStatsAggregator,
+ treePoint: TreePoint,
+ nodeIndex: Int): Unit = {
+ val label = treePoint.label
+ val nodeOffset = agg.getNodeOffset(nodeIndex)
+ // Iterate over all features.
+ val numFeatures = agg.numFeatures
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ val binIndex = treePoint.binnedFeatures(featureIndex)
+ agg.nodeUpdate(nodeOffset, featureIndex, binIndex, label)
+ featureIndex += 1
+ }
+ }
+
+ /**
* Returns an array of optimal splits for a group of nodes at a given level
*
* @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
* @param parentImpurities Impurities for all parent nodes for the current level
* @param metadata Learning and dataset metadata
* @param level Level of the tree
- * @param splits possible splits for all features
- * @param bins possible bins for all features, indexed as (numFeatures)(numBins)
+ * @param nodes Array of all nodes in the tree. Used for matching data points to nodes.
+ * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
+ * @param bins possible bins for all features, indexed (numFeatures)(numBins)
* @param numGroups total number of node groups at the current level. Default value is set to 1.
* @param groupIndex index of the node group being processed. Default value is set to 0.
* @return array of splits with best splits for all nodes at a given level.
@@ -527,88 +635,22 @@ object DecisionTree extends Serializable with Logging {
// numNodes: Number of nodes in this (level of tree, group),
// where nodes at deeper (larger) levels may be divided into groups.
- val numNodes = (1 << level) / numGroups
+ val numNodes = Node.maxNodesInLevel(level) / numGroups
logDebug("numNodes = " + numNodes)
- // Find the number of features by looking at the first sample.
- val numFeatures = metadata.numFeatures
- logDebug("numFeatures = " + numFeatures)
-
- // numBins: Number of bins = 1 + number of possible splits
- val numBins = bins(0).length
- logDebug("numBins = " + numBins)
-
- val numClasses = metadata.numClasses
- logDebug("numClasses = " + numClasses)
-
- val isMulticlass = metadata.isMulticlass
- logDebug("isMulticlass = " + isMulticlass)
-
- val isMulticlassWithCategoricalFeatures = metadata.isMulticlassWithCategoricalFeatures
- logDebug("isMultiClassWithCategoricalFeatures = " + isMulticlassWithCategoricalFeatures)
+ logDebug("numFeatures = " + metadata.numFeatures)
+ logDebug("numClasses = " + metadata.numClasses)
+ logDebug("isMulticlass = " + metadata.isMulticlass)
+ logDebug("isMulticlassWithCategoricalFeatures = " +
+ metadata.isMulticlassWithCategoricalFeatures)
// shift when more than one group is used at deep tree level
val groupShift = numNodes * groupIndex
- /**
- * Get the node index corresponding to this data point.
- * This function mimics prediction, passing an example from the root node down to a node
- * at the current level being trained; that node's index is returned.
- *
- * @return Leaf index if the data point reaches a leaf.
- * Otherwise, last node reachable in tree matching this example.
- */
- def predictNodeIndex(node: Node, binnedFeatures: Array[Int]): Int = {
- if (node.isLeaf) {
- node.id
- } else {
- val featureIndex = node.split.get.feature
- val splitLeft = node.split.get.featureType match {
- case Continuous => {
- val binIndex = binnedFeatures(featureIndex)
- val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
- // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold]
- // We do not need to check lowSplit since bins are separated by splits.
- featureValueUpperBound <= node.split.get.threshold
- }
- case Categorical => {
- val featureValue = if (metadata.isUnordered(featureIndex)) {
- binnedFeatures(featureIndex)
- } else {
- val binIndex = binnedFeatures(featureIndex)
- bins(featureIndex)(binIndex).category
- }
- node.split.get.categories.contains(featureValue)
- }
- case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.")
- }
- if (node.leftNode.isEmpty || node.rightNode.isEmpty) {
- // Return index from next layer of nodes to train
- if (splitLeft) {
- node.id * 2 + 1 // left
- } else {
- node.id * 2 + 2 // right
- }
- } else {
- if (splitLeft) {
- predictNodeIndex(node.leftNode.get, binnedFeatures)
- } else {
- predictNodeIndex(node.rightNode.get, binnedFeatures)
- }
- }
- }
- }
-
- def nodeIndexToLevel(idx: Int): Int = {
- if (idx == 0) {
- 0
- } else {
- math.floor(math.log(idx) / math.log(2)).toInt
- }
- }
-
- // Used for treePointToNodeIndex
- val levelOffset = (1 << level) - 1
+ // Used for treePointToNodeIndex to get an index for this (level, group).
+ // - Node.startIndexInLevel(level) gives the global index offset for nodes at this level.
+ // - groupShift corrects for groups in this level before the current group.
+ val globalNodeIndexOffset = Node.startIndexInLevel(level) + groupShift
/**
* Find the node index for the given example.
@@ -619,661 +661,254 @@ object DecisionTree extends Serializable with Logging {
if (level == 0) {
0
} else {
- val globalNodeIndex = predictNodeIndex(nodes(0), treePoint.binnedFeatures)
- // Get index for this (level, group).
- globalNodeIndex - levelOffset - groupShift
- }
- }
-
- /**
- * Increment aggregate in location for (node, feature, bin, label).
- *
- * @param treePoint Data point being aggregated.
- * @param agg Array storing aggregate calculation, of size:
- * numClasses * numBins * numFeatures * numNodes.
- * Indexed by (node, feature, bin, label) where label is the least significant bit.
- * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
- */
- def updateBinForOrderedFeature(
- treePoint: TreePoint,
- agg: Array[Double],
- nodeIndex: Int,
- featureIndex: Int): Unit = {
- // Update the left or right count for one bin.
- val aggIndex =
- numClasses * numBins * numFeatures * nodeIndex +
- numClasses * numBins * featureIndex +
- numClasses * treePoint.binnedFeatures(featureIndex) +
- treePoint.label.toInt
- agg(aggIndex) += 1
- }
-
- /**
- * Increment aggregate in location for (nodeIndex, featureIndex, [bins], label),
- * where [bins] ranges over all bins.
- * Updates left or right side of aggregate depending on split.
- *
- * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
- * @param treePoint Data point being aggregated.
- * @param agg Indexed by (left/right, node, feature, bin, label)
- * where label is the least significant bit.
- * The left/right specifier is a 0/1 index indicating left/right child info.
- * @param rightChildShift Offset for right side of agg.
- */
- def updateBinForUnorderedFeature(
- nodeIndex: Int,
- featureIndex: Int,
- treePoint: TreePoint,
- agg: Array[Double],
- rightChildShift: Int): Unit = {
- val featureValue = treePoint.binnedFeatures(featureIndex)
- // Update the left or right count for one bin.
- val aggShift =
- numClasses * numBins * numFeatures * nodeIndex +
- numClasses * numBins * featureIndex +
- treePoint.label.toInt
- // Find all matching bins and increment their values
- val featureCategories = metadata.featureArity(featureIndex)
- val numCategoricalBins = (1 << featureCategories - 1) - 1
- var binIndex = 0
- while (binIndex < numCategoricalBins) {
- val aggIndex = aggShift + binIndex * numClasses
- if (bins(featureIndex)(binIndex).highSplit.categories.contains(featureValue)) {
- agg(aggIndex) += 1
- } else {
- agg(rightChildShift + aggIndex) += 1
- }
- binIndex += 1
- }
- }
-
- /**
- * Helper for binSeqOp.
- *
- * @param agg Array storing aggregate calculation, of size:
- * numClasses * numBins * numFeatures * numNodes.
- * Indexed by (node, feature, bin, label) where label is the least significant bit.
- * @param treePoint Data point being aggregated.
- * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
- */
- def binaryOrNotCategoricalBinSeqOp(
- agg: Array[Double],
- treePoint: TreePoint,
- nodeIndex: Int): Unit = {
- // Iterate over all features.
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex)
- featureIndex += 1
- }
- }
-
- val rightChildShift = numClasses * numBins * numFeatures * numNodes
-
- /**
- * Helper for binSeqOp.
- *
- * @param agg Array storing aggregate calculation.
- * For ordered features, this is of size:
- * numClasses * numBins * numFeatures * numNodes.
- * For unordered features, this is of size:
- * 2 * numClasses * numBins * numFeatures * numNodes.
- * @param treePoint Data point being aggregated.
- * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
- */
- def multiclassWithCategoricalBinSeqOp(
- agg: Array[Double],
- treePoint: TreePoint,
- nodeIndex: Int): Unit = {
- val label = treePoint.label
- // Iterate over all features.
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- if (metadata.isUnordered(featureIndex)) {
- updateBinForUnorderedFeature(nodeIndex, featureIndex, treePoint, agg, rightChildShift)
- } else {
- updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex)
- }
- featureIndex += 1
- }
- }
-
- /**
- * Performs a sequential aggregation over a partition for regression.
- * For l nodes, k features,
- * the count, sum, sum of squares of one of the p bins is incremented.
- *
- * @param agg Array storing aggregate calculation, updated by this function.
- * Size: 3 * numBins * numFeatures * numNodes
- * @param treePoint Data point being aggregated.
- * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
- * @return agg
- */
- def regressionBinSeqOp(agg: Array[Double], treePoint: TreePoint, nodeIndex: Int): Unit = {
- val label = treePoint.label
- // Iterate over all features.
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- // Update count, sum, and sum^2 for one bin.
- val binIndex = treePoint.binnedFeatures(featureIndex)
- val aggIndex =
- 3 * numBins * numFeatures * nodeIndex +
- 3 * numBins * featureIndex +
- 3 * binIndex
- agg(aggIndex) += 1
- agg(aggIndex + 1) += label
- agg(aggIndex + 2) += label * label
- featureIndex += 1
+ val globalNodeIndex =
+ predictNodeIndex(nodes(1), treePoint.binnedFeatures, bins, metadata.unorderedFeatures)
+ globalNodeIndex - globalNodeIndexOffset
}
}
/**
* Performs a sequential aggregation over a partition.
- * For l nodes, k features,
- * For classification:
- * Either the left count or the right count of one of the bins is
- * incremented based upon whether the feature is classified as 0 or 1.
- * For regression:
- * The count, sum, sum of squares of one of the bins is incremented.
*
- * @param agg Array storing aggregate calculation, updated by this function.
- * Size for classification:
- * numClasses * numBins * numFeatures * numNodes for ordered features, or
- * 2 * numClasses * numBins * numFeatures * numNodes for unordered features.
- * Size for regression:
- * 3 * numBins * numFeatures * numNodes.
+ * Each data point contributes to one node. For each feature,
+ * the aggregate sufficient statistics are updated for the relevant bins.
+ *
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
+ * each (node, feature, bin).
* @param treePoint Data point being aggregated.
* @return agg
*/
- def binSeqOp(agg: Array[Double], treePoint: TreePoint): Array[Double] = {
+ def binSeqOp(
+ agg: DTStatsAggregator,
+ treePoint: TreePoint): DTStatsAggregator = {
val nodeIndex = treePointToNodeIndex(treePoint)
// If the example does not reach this level, then nodeIndex < 0.
// If the example reaches this level but is handled in a different group,
// then either nodeIndex < 0 (previous group) or nodeIndex >= numNodes (later group).
if (nodeIndex >= 0 && nodeIndex < numNodes) {
- if (metadata.isClassification) {
- if (isMulticlassWithCategoricalFeatures) {
- multiclassWithCategoricalBinSeqOp(agg, treePoint, nodeIndex)
- } else {
- binaryOrNotCategoricalBinSeqOp(agg, treePoint, nodeIndex)
- }
+ if (metadata.unorderedFeatures.isEmpty) {
+ orderedBinSeqOp(agg, treePoint, nodeIndex)
} else {
- regressionBinSeqOp(agg, treePoint, nodeIndex)
+ mixedBinSeqOp(agg, treePoint, nodeIndex, bins, metadata.unorderedFeatures)
}
}
agg
}
- // Calculate bin aggregate length for classification or regression.
- val binAggregateLength = numNodes * getElementsPerNode(metadata, numBins)
- logDebug("binAggregateLength = " + binAggregateLength)
-
- /**
- * Combines the aggregates from partitions.
- * @param agg1 Array containing aggregates from one or more partitions
- * @param agg2 Array containing aggregates from one or more partitions
- * @return Combined aggregate from agg1 and agg2
- */
- def binCombOp(agg1: Array[Double], agg2: Array[Double]): Array[Double] = {
- var index = 0
- val combinedAggregate = new Array[Double](binAggregateLength)
- while (index < binAggregateLength) {
- combinedAggregate(index) = agg1(index) + agg2(index)
- index += 1
- }
- combinedAggregate
- }
-
// Calculate bin aggregates.
timer.start("aggregation")
- val binAggregates = {
- input.treeAggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp)
+ val binAggregates: DTStatsAggregator = {
+ val initAgg = new DTStatsAggregator(metadata, numNodes)
+ input.treeAggregate(initAgg)(binSeqOp, DTStatsAggregator.binCombOp)
}
timer.stop("aggregation")
- logDebug("binAggregates.length = " + binAggregates.length)
- /**
- * Calculate the information gain for a given (feature, split) based upon left/right aggregates.
- * @param leftNodeAgg left node aggregates for this (feature, split)
- * @param rightNodeAgg right node aggregate for this (feature, split)
- * @param topImpurity impurity of the parent node
- * @return information gain and statistics for all splits
- */
- def calculateGainForSplit(
- leftNodeAgg: Array[Double],
- rightNodeAgg: Array[Double],
- topImpurity: Double): InformationGainStats = {
- if (metadata.isClassification) {
- val leftTotalCount = leftNodeAgg.sum
- val rightTotalCount = rightNodeAgg.sum
-
- val impurity = {
- if (level > 0) {
- topImpurity
- } else {
- // Calculate impurity for root node.
- val rootNodeCounts = new Array[Double](numClasses)
- var classIndex = 0
- while (classIndex < numClasses) {
- rootNodeCounts(classIndex) = leftNodeAgg(classIndex) + rightNodeAgg(classIndex)
- classIndex += 1
- }
- metadata.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount)
- }
- }
-
- val totalCount = leftTotalCount + rightTotalCount
- if (totalCount == 0) {
- // Return arbitrary prediction.
- return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
- }
-
- // Sum of count for each label
- val leftrightNodeAgg: Array[Double] =
- leftNodeAgg.zip(rightNodeAgg).map { case (leftCount, rightCount) =>
- leftCount + rightCount
- }
-
- def indexOfLargestArrayElement(array: Array[Double]): Int = {
- val result = array.foldLeft(-1, Double.MinValue, 0) {
- case ((maxIndex, maxValue, currentIndex), currentValue) =>
- if (currentValue > maxValue) {
- (currentIndex, currentValue, currentIndex + 1)
- } else {
- (maxIndex, maxValue, currentIndex + 1)
- }
- }
- if (result._1 < 0) {
- throw new RuntimeException("DecisionTree internal error:" +
- " calculateGainForSplit failed in indexOfLargestArrayElement")
- }
- result._1
- }
-
- val predict = indexOfLargestArrayElement(leftrightNodeAgg)
- val prob = leftrightNodeAgg(predict) / totalCount
-
- val leftImpurity = if (leftTotalCount == 0) {
- topImpurity
- } else {
- metadata.impurity.calculate(leftNodeAgg, leftTotalCount)
- }
- val rightImpurity = if (rightTotalCount == 0) {
- topImpurity
- } else {
- metadata.impurity.calculate(rightNodeAgg, rightTotalCount)
- }
-
- val leftWeight = leftTotalCount / totalCount
- val rightWeight = rightTotalCount / totalCount
-
- val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
-
- new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
-
- } else {
- // Regression
-
- val leftCount = leftNodeAgg(0)
- val leftSum = leftNodeAgg(1)
- val leftSumSquares = leftNodeAgg(2)
+ // Calculate best splits for all nodes at a given level
+ timer.start("chooseSplits")
+ val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
+ // Iterating over all nodes at this level
+ var nodeIndex = 0
+ while (nodeIndex < numNodes) {
+ val nodeImpurity = parentImpurities(globalNodeIndexOffset + nodeIndex)
+ logDebug("node impurity = " + nodeImpurity)
+ bestSplits(nodeIndex) =
+ binsToBestSplit(binAggregates, nodeIndex, nodeImpurity, level, metadata, splits)
+ logDebug("best split = " + bestSplits(nodeIndex)._1)
+ nodeIndex += 1
+ }
+ timer.stop("chooseSplits")
- val rightCount = rightNodeAgg(0)
- val rightSum = rightNodeAgg(1)
- val rightSumSquares = rightNodeAgg(2)
+ bestSplits
+ }
- val impurity = {
- if (level > 0) {
- topImpurity
- } else {
- // Calculate impurity for root node.
- val count = leftCount + rightCount
- val sum = leftSum + rightSum
- val sumSquares = leftSumSquares + rightSumSquares
- metadata.impurity.calculate(count, sum, sumSquares)
- }
- }
+ /**
+ * Calculate the information gain for a given (feature, split) based upon left/right aggregates.
+ * @param leftImpurityCalculator left node aggregates for this (feature, split)
+ * @param rightImpurityCalculator right node aggregate for this (feature, split)
+ * @param topImpurity impurity of the parent node
+ * @return information gain and statistics for all splits
+ */
+ private def calculateGainForSplit(
+ leftImpurityCalculator: ImpurityCalculator,
+ rightImpurityCalculator: ImpurityCalculator,
+ topImpurity: Double,
+ level: Int,
+ metadata: DecisionTreeMetadata): InformationGainStats = {
- if (leftCount == 0) {
- return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,
- rightSum / rightCount)
- }
- if (rightCount == 0) {
- return new InformationGainStats(0, topImpurity, topImpurity,
- Double.MinValue, leftSum / leftCount)
- }
+ val leftCount = leftImpurityCalculator.count
+ val rightCount = rightImpurityCalculator.count
- val leftImpurity = metadata.impurity.calculate(leftCount, leftSum, leftSumSquares)
- val rightImpurity = metadata.impurity.calculate(rightCount, rightSum, rightSumSquares)
+ val totalCount = leftCount + rightCount
+ if (totalCount == 0) {
+ // Return arbitrary prediction.
+ return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
+ }
- val leftWeight = leftCount.toDouble / (leftCount + rightCount)
- val rightWeight = rightCount.toDouble / (leftCount + rightCount)
+ val parentNodeAgg = leftImpurityCalculator.copy
+ parentNodeAgg.add(rightImpurityCalculator)
+ // impurity of parent node
+ val impurity = if (level > 0) {
+ topImpurity
+ } else {
+ parentNodeAgg.calculate()
+ }
- val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+ val predict = parentNodeAgg.predict
+ val prob = parentNodeAgg.prob(predict)
- val predict = (leftSum + rightSum) / (leftCount + rightCount)
- new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict)
- }
- }
+ val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
+ val rightImpurity = rightImpurityCalculator.calculate()
- /**
- * Extracts left and right split aggregates.
- * @param binData Aggregate array slice from getBinDataForNode.
- * For classification:
- * For unordered features, this is leftChildData ++ rightChildData,
- * each of which is indexed by (feature, split/bin, class),
- * with class being the least significant bit.
- * For ordered features, this is of size numClasses * numBins * numFeatures.
- * For regression:
- * This is of size 2 * numFeatures * numBins.
- * @return (leftNodeAgg, rightNodeAgg) pair of arrays.
- * For classification, each array is of size (numFeatures, (numBins - 1), numClasses).
- * For regression, each array is of size (numFeatures, (numBins - 1), 3).
- *
- */
- def extractLeftRightNodeAggregates(
- binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = {
-
-
- /**
- * The input binData is indexed as (feature, bin, class).
- * This computes cumulative sums over splits.
- * Each (feature, class) pair is handled separately.
- * Note: numSplits = numBins - 1.
- * @param leftNodeAgg Each (feature, class) slice is an array over splits.
- * Element i (i = 0, ..., numSplits - 2) is set to be
- * the cumulative sum (from left) over binData for bins 0, ..., i.
- * @param rightNodeAgg Each (feature, class) slice is an array over splits.
- * Element i (i = 1, ..., numSplits - 1) is set to be
- * the cumulative sum (from right) over binData for bins
- * numBins - 1, ..., numBins - 1 - i.
- */
- def findAggForOrderedFeatureClassification(
- leftNodeAgg: Array[Array[Array[Double]]],
- rightNodeAgg: Array[Array[Array[Double]]],
- featureIndex: Int) {
-
- // shift for this featureIndex
- val shift = numClasses * featureIndex * numBins
-
- var classIndex = 0
- while (classIndex < numClasses) {
- // left node aggregate for the lowest split
- leftNodeAgg(featureIndex)(0)(classIndex) = binData(shift + classIndex)
- // right node aggregate for the highest split
- rightNodeAgg(featureIndex)(numBins - 2)(classIndex)
- = binData(shift + (numClasses * (numBins - 1)) + classIndex)
- classIndex += 1
- }
+ val leftWeight = leftCount / totalCount.toDouble
+ val rightWeight = rightCount / totalCount.toDouble
- // Iterate over all splits.
- var splitIndex = 1
- while (splitIndex < numBins - 1) {
- // calculating left node aggregate for a split as a sum of left node aggregate of a
- // lower split and the left bin aggregate of a bin where the split is a high split
- var innerClassIndex = 0
- while (innerClassIndex < numClasses) {
- leftNodeAgg(featureIndex)(splitIndex)(innerClassIndex)
- = binData(shift + numClasses * splitIndex + innerClassIndex) +
- leftNodeAgg(featureIndex)(splitIndex - 1)(innerClassIndex)
- rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(innerClassIndex) =
- binData(shift + (numClasses * (numBins - 1 - splitIndex) + innerClassIndex)) +
- rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(innerClassIndex)
- innerClassIndex += 1
- }
- splitIndex += 1
- }
- }
+ val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
- /**
- * Reshape binData for this feature.
- * Indexes binData as (feature, split, class) with class as the least significant bit.
- * @param leftNodeAgg leftNodeAgg(featureIndex)(splitIndex)(classIndex) = aggregate value
- */
- def findAggForUnorderedFeatureClassification(
- leftNodeAgg: Array[Array[Array[Double]]],
- rightNodeAgg: Array[Array[Array[Double]]],
- featureIndex: Int) {
-
- val rightChildShift = numClasses * numBins * numFeatures
- var splitIndex = 0
- while (splitIndex < numBins - 1) {
- var classIndex = 0
- while (classIndex < numClasses) {
- // shift for this featureIndex
- val shift = numClasses * featureIndex * numBins + splitIndex * numClasses
- val leftBinValue = binData(shift + classIndex)
- val rightBinValue = binData(rightChildShift + shift + classIndex)
- leftNodeAgg(featureIndex)(splitIndex)(classIndex) = leftBinValue
- rightNodeAgg(featureIndex)(splitIndex)(classIndex) = rightBinValue
- classIndex += 1
- }
- splitIndex += 1
- }
- }
+ new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
+ }
- def findAggForRegression(
- leftNodeAgg: Array[Array[Array[Double]]],
- rightNodeAgg: Array[Array[Array[Double]]],
- featureIndex: Int) {
-
- // shift for this featureIndex
- val shift = 3 * featureIndex * numBins
- // left node aggregate for the lowest split
- leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0)
- leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1)
- leftNodeAgg(featureIndex)(0)(2) = binData(shift + 2)
-
- // right node aggregate for the highest split
- rightNodeAgg(featureIndex)(numBins - 2)(0) =
- binData(shift + (3 * (numBins - 1)))
- rightNodeAgg(featureIndex)(numBins - 2)(1) =
- binData(shift + (3 * (numBins - 1)) + 1)
- rightNodeAgg(featureIndex)(numBins - 2)(2) =
- binData(shift + (3 * (numBins - 1)) + 2)
-
- // Iterate over all splits.
- var splitIndex = 1
- while (splitIndex < numBins - 1) {
- var i = 0 // index for regression histograms
- while (i < 3) { // count, sum, sum^2
- // calculating left node aggregate for a split as a sum of left node aggregate of a
- // lower split and the left bin aggregate of a bin where the split is a high split
- leftNodeAgg(featureIndex)(splitIndex)(i) = binData(shift + 3 * splitIndex + i) +
- leftNodeAgg(featureIndex)(splitIndex - 1)(i)
- // calculating right node aggregate for a split as a sum of right node aggregate of a
- // higher split and the right bin aggregate of a bin where the split is a low split
- rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(i) =
- binData(shift + (3 * (numBins - 1 - splitIndex) + i)) +
- rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(i)
- i += 1
- }
- splitIndex += 1
- }
- }
+ /**
+ * Find the best split for a node.
+ * @param binAggregates Bin statistics.
+ * @param nodeIndex Index for node to split in this (level, group).
+ * @param nodeImpurity Impurity of the node (nodeIndex).
+ * @return tuple for best split: (Split, information gain)
+ */
+ private def binsToBestSplit(
+ binAggregates: DTStatsAggregator,
+ nodeIndex: Int,
+ nodeImpurity: Double,
+ level: Int,
+ metadata: DecisionTreeMetadata,
+ splits: Array[Array[Split]]): (Split, InformationGainStats) = {
- if (metadata.isClassification) {
- // Initialize left and right split aggregates.
- val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
- val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- if (metadata.isUnordered(featureIndex)) {
- findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
- } else {
- findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
- }
- featureIndex += 1
- }
- (leftNodeAgg, rightNodeAgg)
- } else {
- // Regression
- // Initialize left and right split aggregates.
- val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3)
- val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3)
- // Iterate over all features.
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex)
- featureIndex += 1
- }
- (leftNodeAgg, rightNodeAgg)
- }
- }
+ logDebug("node impurity = " + nodeImpurity)
- /**
- * Calculates information gain for all nodes splits.
- */
- def calculateGainsForAllNodeSplits(
- leftNodeAgg: Array[Array[Array[Double]]],
- rightNodeAgg: Array[Array[Array[Double]]],
- nodeImpurity: Double): Array[Array[InformationGainStats]] = {
- val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1)
-
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- val numSplitsForFeature = getNumSplitsForFeature(featureIndex)
+ // For each (feature, split), calculate the gain, and select the best (feature, split).
+ Range(0, metadata.numFeatures).map { featureIndex =>
+ val numSplits = metadata.numSplits(featureIndex)
+ if (metadata.isContinuous(featureIndex)) {
+ // Cumulative sum (scanLeft) of bin statistics.
+ // Afterwards, binAggregates for a bin is the sum of aggregates for
+ // that bin + all preceding bins.
+ val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex)
var splitIndex = 0
- while (splitIndex < numSplitsForFeature) {
- gains(featureIndex)(splitIndex) =
- calculateGainForSplit(leftNodeAgg(featureIndex)(splitIndex),
- rightNodeAgg(featureIndex)(splitIndex), nodeImpurity)
+ while (splitIndex < numSplits) {
+ binAggregates.mergeForNodeFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
splitIndex += 1
}
- featureIndex += 1
- }
- gains
- }
-
- /**
- * Get the number of splits for a feature.
- */
- def getNumSplitsForFeature(featureIndex: Int): Int = {
- if (metadata.isContinuous(featureIndex)) {
- numBins - 1
+ // Find best split.
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits).map { case splitIdx =>
+ val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
+ val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
+ rightChildStats.subtract(leftChildStats)
+ val gainStats =
+ calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
+ (splitIdx, gainStats)
+ }.maxBy(_._2.gain)
+ (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+ } else if (metadata.isUnordered(featureIndex)) {
+ // Unordered categorical feature
+ val (leftChildOffset, rightChildOffset) =
+ binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex)
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits).map { splitIndex =>
+ val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
+ val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
+ val gainStats =
+ calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
+ (splitIndex, gainStats)
+ }.maxBy(_._2.gain)
+ (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else {
- // Categorical feature
- val featureCategories = metadata.featureArity(featureIndex)
- if (metadata.isUnordered(featureIndex)) {
- (1 << featureCategories - 1) - 1
- } else {
- featureCategories
- }
- }
- }
-
- /**
- * Find the best split for a node.
- * @param binData Bin data slice for this node, given by getBinDataForNode.
- * @param nodeImpurity impurity of the top node
- * @return tuple of split and information gain
- */
- def binsToBestSplit(
- binData: Array[Double],
- nodeImpurity: Double): (Split, InformationGainStats) = {
-
- logDebug("node impurity = " + nodeImpurity)
-
- // Extract left right node aggregates.
- val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData)
-
- // Calculate gains for all splits.
- val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity)
-
- val (bestFeatureIndex, bestSplitIndex, gainStats) = {
- // Initialize with infeasible values.
- var bestFeatureIndex = Int.MinValue
- var bestSplitIndex = Int.MinValue
- var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, -1.0)
- // Iterate over features.
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- // Iterate over all splits.
- var splitIndex = 0
- val numSplitsForFeature = getNumSplitsForFeature(featureIndex)
- while (splitIndex < numSplitsForFeature) {
- val gainStats = gains(featureIndex)(splitIndex)
- if (gainStats.gain > bestGainStats.gain) {
- bestGainStats = gainStats
- bestFeatureIndex = featureIndex
- bestSplitIndex = splitIndex
+ // Ordered categorical feature
+ val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex)
+ val numBins = metadata.numBins(featureIndex)
+
+ /* Each bin is one category (feature value).
+ * The bins are ordered based on centroidForCategories, and this ordering determines which
+ * splits are considered. (With K categories, we consider K - 1 possible splits.)
+ *
+ * centroidForCategories is a list: (category, centroid)
+ */
+ val centroidForCategories = if (metadata.isMulticlass) {
+ // For categorical variables in multiclass classification,
+ // the bins are ordered by the impurity of their corresponding labels.
+ Range(0, numBins).map { case featureValue =>
+ val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+ val centroid = if (categoryStats.count != 0) {
+ categoryStats.calculate()
+ } else {
+ Double.MaxValue
}
- splitIndex += 1
+ (featureValue, centroid)
+ }
+ } else { // regression or binary classification
+ // For categorical variables in regression and binary classification,
+ // the bins are ordered by the centroid of their corresponding labels.
+ Range(0, numBins).map { case featureValue =>
+ val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+ val centroid = if (categoryStats.count != 0) {
+ categoryStats.predict
+ } else {
+ Double.MaxValue
+ }
+ (featureValue, centroid)
}
- featureIndex += 1
}
- (bestFeatureIndex, bestSplitIndex, bestGainStats)
- }
- logDebug("best split = " + splits(bestFeatureIndex)(bestSplitIndex))
- logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex))
+ logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
- (splits(bestFeatureIndex)(bestSplitIndex), gainStats)
- }
+ // bins sorted by centroids
+ val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
- /**
- * Get bin data for one node.
- */
- def getBinDataForNode(node: Int): Array[Double] = {
- if (metadata.isClassification) {
- if (isMulticlassWithCategoricalFeatures) {
- val shift = numClasses * node * numBins * numFeatures
- val rightChildShift = numClasses * numBins * numFeatures * numNodes
- val binsForNode = {
- val leftChildData
- = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures)
- val rightChildData
- = binAggregates.slice(rightChildShift + shift,
- rightChildShift + shift + numClasses * numBins * numFeatures)
- leftChildData ++ rightChildData
- }
- binsForNode
- } else {
- val shift = numClasses * node * numBins * numFeatures
- val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures)
- binsForNode
+ logDebug("Sorted centroids for categorical variable = " +
+ categoriesSortedByCentroid.mkString(","))
+
+ // Cumulative sum (scanLeft) of bin statistics.
+ // Afterwards, the stats for each bin are the sum of the aggregates for
+ // that bin and all preceding bins.
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ val currentCategory = categoriesSortedByCentroid(splitIndex)._1
+ val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
+ binAggregates.mergeForNodeFeature(nodeFeatureOffset, nextCategory, currentCategory)
+ splitIndex += 1
}
- } else {
- // Regression
- val shift = 3 * node * numBins * numFeatures
- val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures)
- binsForNode
+ // lastCategory = index of bin with total aggregates for this (node, feature)
+ val lastCategory = categoriesSortedByCentroid.last._1
+ // Find best split.
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits).map { splitIndex =>
+ val featureValue = categoriesSortedByCentroid(splitIndex)._1
+ val leftChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+ val rightChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
+ rightChildStats.subtract(leftChildStats)
+ val gainStats =
+ calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
+ (splitIndex, gainStats)
+ }.maxBy(_._2.gain)
+ val categoriesForSplit =
+ categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
+ val bestFeatureSplit =
+ new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
+ (bestFeatureSplit, bestFeatureGainStats)
}
- }
-
- // Calculate best splits for all nodes at a given level
- timer.start("chooseSplits")
- val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
- // Iterating over all nodes at this level
- var node = 0
- while (node < numNodes) {
- val nodeImpurityIndex = (1 << level) - 1 + node + groupShift
- val binsForNode: Array[Double] = getBinDataForNode(node)
- logDebug("nodeImpurityIndex = " + nodeImpurityIndex)
- val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
- logDebug("parent node impurity = " + parentNodeImpurity)
- bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
- node += 1
- }
- timer.stop("chooseSplits")
-
- bestSplits
+ }.maxBy(_._2.gain)
}
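For reference, the cumulative-sum trick used above for ordered features can be sketched with plain arrays (an illustrative stand-in, not the DTStatsAggregator API): bins are merged left to right, so the stats at split index i cover bins 0..i, and the right-child stats are the totals in the last bin minus the left-child stats. The class counts below are hypothetical.

object PrefixSumSplitSketch {
  def main(args: Array[String]): Unit = {
    // Per-bin class counts for one (node, feature): stats(binIndex)(classIndex).
    val stats = Array(Array(4.0, 1.0), Array(2.0, 3.0), Array(0.0, 5.0))
    // In-place cumulative sum: afterwards stats(i) holds bins 0..i combined.
    for (i <- 1 until stats.length; c <- stats(i).indices) {
      stats(i)(c) += stats(i - 1)(c)
    }
    val total = stats.last
    // Evaluate each of the (numBins - 1) candidate splits.
    for (splitIndex <- 0 until stats.length - 1) {
      val left = stats(splitIndex)
      val right = total.zip(left).map { case (t, l) => t - l }
      println(s"split $splitIndex: left=${left.mkString(",")} right=${right.mkString(",")}")
    }
  }
}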
/**
* Get the number of values to be stored per node in the bin aggregates.
- *
- * @param numBins Number of bins = 1 + number of possible splits.
*/
- private def getElementsPerNode(metadata: DecisionTreeMetadata, numBins: Int): Int = {
+ private def getElementsPerNode(metadata: DecisionTreeMetadata): Int = {
+ val totalBins = metadata.numBins.sum
if (metadata.isClassification) {
- if (metadata.isMulticlassWithCategoricalFeatures) {
- 2 * metadata.numClasses * numBins * metadata.numFeatures
- } else {
- metadata.numClasses * numBins * metadata.numFeatures
- }
+ metadata.numClasses * totalBins
} else {
- 3 * numBins * metadata.numFeatures
+ 3 * totalBins
}
}
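A quick worked example of the sizing above, assuming hypothetical bin counts (a continuous feature with 32 bins, a binary categorical feature, and a 3-category unordered feature) and 3 classes; pasteable into a Scala REPL:

val numBins = Array(32, 2, 6)                     // hypothetical bins per feature
val totalBins = numBins.sum                       // 40
val numClasses = 3
val classificationElems = numClasses * totalBins  // 120 doubles per node
val regressionElems = 3 * totalBins               // count, sum, sum-of-squares per bin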
@@ -1284,6 +919,7 @@ object DecisionTree extends Serializable with Logging {
* Continuous features:
* For each feature, there are numBins - 1 possible splits representing the possible binary
* decisions at each node in the tree.
+ * This finds locations (feature values) for splits using a subsample of the data.
*
* Categorical features:
* For each feature, there is 1 bin per split.
@@ -1292,7 +928,6 @@ object DecisionTree extends Serializable with Logging {
* For multiclass classification with a low-arity feature
* (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
* the feature is split based on subsets of categories.
- * There are (1 << maxFeatureValue - 1) - 1 splits.
* (b) "ordered features"
* For regression and binary classification,
* and for multiclass classification with a high-arity feature,
@@ -1302,7 +937,7 @@ object DecisionTree extends Serializable with Logging {
* @param metadata Learning and dataset metadata
* @return A tuple of (splits, bins).
* Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]]
- * of size (numFeatures, numBins - 1).
+ * of size (numFeatures, numSplits).
* Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]]
* of size (numFeatures, numBins).
*/
@@ -1310,84 +945,80 @@ object DecisionTree extends Serializable with Logging {
input: RDD[LabeledPoint],
metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = {
- val count = input.count()
+ logDebug("isMulticlass = " + metadata.isMulticlass)
- // Find the number of features by looking at the first sample
- val numFeatures = input.take(1)(0).features.size
-
- val maxBins = metadata.maxBins
- val numBins = if (maxBins <= count) maxBins else count.toInt
- logDebug("numBins = " + numBins)
- val isMulticlass = metadata.isMulticlass
- logDebug("isMulticlass = " + isMulticlass)
-
- /*
- * Ensure numBins is always greater than the categories. For multiclass classification,
- * numBins should be greater than 2^(maxCategories - 1) - 1.
- * It's a limitation of the current implementation but a reasonable trade-off since features
- * with large number of categories get favored over continuous features.
- *
- * This needs to be checked here instead of in Strategy since numBins can be determined
- * by the number of training examples.
- * TODO: Allow this case, where we simply will know nothing about some categories.
- */
- if (metadata.featureArity.size > 0) {
- val maxCategoriesForFeatures = metadata.featureArity.maxBy(_._2)._2
- require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " +
- "in categorical features")
- }
-
- // Calculate the number of sample for approximate quantile calculation.
- val requiredSamples = numBins*numBins
- val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
- logDebug("fraction of data used for calculating quantiles = " + fraction)
+ val numFeatures = metadata.numFeatures
- // sampled input for RDD calculation
- val sampledInput =
+ // Sample the input only if there are continuous features.
+ val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous)
+ val sampledInput = if (hasContinuousFeatures) {
+ // Calculate the number of samples for approximate quantile calculation.
+ val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
+ val fraction = if (requiredSamples < metadata.numExamples) {
+ requiredSamples.toDouble / metadata.numExamples
+ } else {
+ 1.0
+ }
+ logDebug("fraction of data used for calculating quantiles = " + fraction)
input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect()
- val numSamples = sampledInput.length
-
- val stride: Double = numSamples.toDouble / numBins
- logDebug("stride = " + stride)
+ } else {
+ new Array[LabeledPoint](0)
+ }
metadata.quantileStrategy match {
case Sort =>
- val splits = Array.ofDim[Split](numFeatures, numBins - 1)
- val bins = Array.ofDim[Bin](numFeatures, numBins)
+ val splits = new Array[Array[Split]](numFeatures)
+ val bins = new Array[Array[Bin]](numFeatures)
// Find all splits.
-
// Iterate over all features.
var featureIndex = 0
while (featureIndex < numFeatures) {
- // Check whether the feature is continuous.
- val isFeatureContinuous = metadata.isContinuous(featureIndex)
- if (isFeatureContinuous) {
+ val numSplits = metadata.numSplits(featureIndex)
+ val numBins = metadata.numBins(featureIndex)
+ if (metadata.isContinuous(featureIndex)) {
+ val numSamples = sampledInput.length
+ splits(featureIndex) = new Array[Split](numSplits)
+ bins(featureIndex) = new Array[Bin](numBins)
val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
- val stride: Double = numSamples.toDouble / numBins
+ val stride: Double = numSamples.toDouble / metadata.numBins(featureIndex)
logDebug("stride = " + stride)
- for (index <- 0 until numBins - 1) {
- val sampleIndex = index * stride.toInt
+ for (splitIndex <- 0 until numSplits) {
+ val sampleIndex = splitIndex * stride.toInt
// Set threshold halfway in between 2 samples.
val threshold = (featureSamples(sampleIndex) + featureSamples(sampleIndex + 1)) / 2.0
- val split = new Split(featureIndex, threshold, Continuous, List())
- splits(featureIndex)(index) = split
+ splits(featureIndex)(splitIndex) =
+ new Split(featureIndex, threshold, Continuous, List())
}
- } else { // Categorical feature
- val featureCategories = metadata.featureArity(featureIndex)
-
- // Use different bin/split calculation strategy for categorical features in multiclass
- // classification that satisfy the space constraint.
+ bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
+ splits(featureIndex)(0), Continuous, Double.MinValue)
+ for (splitIndex <- 1 until numSplits) {
+ bins(featureIndex)(splitIndex) =
+ new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex),
+ Continuous, Double.MinValue)
+ }
+ bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1),
+ new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue)
+ } else {
+ // Categorical feature
+ val featureArity = metadata.featureArity(featureIndex)
if (metadata.isUnordered(featureIndex)) {
- // 2^(maxFeatureValue- 1) - 1 combinations
- var index = 0
- while (index < (1 << featureCategories - 1) - 1) {
- val categories: List[Double]
- = extractMultiClassCategories(index + 1, featureCategories)
- splits(featureIndex)(index)
- = new Split(featureIndex, Double.MinValue, Categorical, categories)
- bins(featureIndex)(index) = {
- if (index == 0) {
+ // TODO: The second half of the bins are unused. Actually, we could just use
+ // splits and not build bins for unordered features. That should be part of
+ // a later PR since it will require changing other code (using splits instead
+ // of bins in a few places).
+ // Unordered features
+ // 2^(featureArity - 1) - 1 combinations
+ splits(featureIndex) = new Array[Split](numSplits)
+ bins(featureIndex) = new Array[Bin](numBins)
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ val categories: List[Double] =
+ extractMultiClassCategories(splitIndex + 1, featureArity)
+ splits(featureIndex)(splitIndex) =
+ new Split(featureIndex, Double.MinValue, Categorical, categories)
+ bins(featureIndex)(splitIndex) = {
+ if (splitIndex == 0) {
new Bin(
new DummyCategoricalSplit(featureIndex, Categorical),
splits(featureIndex)(0),
@@ -1395,96 +1026,24 @@ object DecisionTree extends Serializable with Logging {
Double.MinValue)
} else {
new Bin(
- splits(featureIndex)(index - 1),
- splits(featureIndex)(index),
+ splits(featureIndex)(splitIndex - 1),
+ splits(featureIndex)(splitIndex),
Categorical,
Double.MinValue)
}
}
- index += 1
- }
- } else { // ordered feature
- /* For a given categorical feature, use a subsample of the data
- * to choose how to arrange possible splits.
- * This examines each category and computes a centroid.
- * These centroids are later used to sort the possible splits.
- * centroidForCategories is a mapping: category (for the given feature) --> centroid
- */
- val centroidForCategories = {
- if (isMulticlass) {
- // For categorical variables in multiclass classification,
- // each bin is a category. The bins are sorted and they
- // are ordered by calculating the impurity of their corresponding labels.
- sampledInput.map(lp => (lp.features(featureIndex), lp.label))
- .groupBy(_._1)
- .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble))
- .map(x => (x._1, x._2.values.toArray))
- .map(x => (x._1, metadata.impurity.calculate(x._2, x._2.sum)))
- } else { // regression or binary classification
- // For categorical variables in regression and binary classification,
- // each bin is a category. The bins are sorted and they
- // are ordered by calculating the centroid of their corresponding labels.
- sampledInput.map(lp => (lp.features(featureIndex), lp.label))
- .groupBy(_._1)
- .mapValues(x => x.map(_._2).sum / x.map(_._1).length)
- }
- }
-
- logDebug("centroid for categories = " + centroidForCategories.mkString(","))
-
- // Check for missing categorical variables and putting them last in the sorted list.
- val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]()
- for (i <- 0 until featureCategories) {
- if (centroidForCategories.contains(i)) {
- fullCentroidForCategories(i) = centroidForCategories(i)
- } else {
- fullCentroidForCategories(i) = Double.MaxValue
- }
- }
-
- // bins sorted by centroids
- val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2)
-
- logDebug("centroid for categorical variable = " + categoriesSortedByCentroid)
-
- var categoriesForSplit = List[Double]()
- categoriesSortedByCentroid.iterator.zipWithIndex.foreach {
- case ((key, value), index) =>
- categoriesForSplit = key :: categoriesForSplit
- splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue,
- Categorical, categoriesForSplit)
- bins(featureIndex)(index) = {
- if (index == 0) {
- new Bin(new DummyCategoricalSplit(featureIndex, Categorical),
- splits(featureIndex)(0), Categorical, key)
- } else {
- new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index),
- Categorical, key)
- }
- }
+ splitIndex += 1
}
+ } else {
+ // Ordered features
+ // Bins correspond to feature values, so we do not need to compute splits or bins
+ // beforehand. Splits are constructed as needed during training.
+ splits(featureIndex) = new Array[Split](0)
+ bins(featureIndex) = new Array[Bin](0)
}
}
featureIndex += 1
}
-
- // Find all bins.
- featureIndex = 0
- while (featureIndex < numFeatures) {
- val isFeatureContinuous = metadata.isContinuous(featureIndex)
- if (isFeatureContinuous) { // Bins for categorical variables are already assigned.
- bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
- splits(featureIndex)(0), Continuous, Double.MinValue)
- for (index <- 1 until numBins - 1) {
- val bin = new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index),
- Continuous, Double.MinValue)
- bins(featureIndex)(index) = bin
- }
- bins(featureIndex)(numBins-1) = new Bin(splits(featureIndex)(numBins-2),
- new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue)
- }
- featureIndex += 1
- }
(splits, bins)
case MinMax =>
throw new UnsupportedOperationException("minmax not supported yet.")
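As an aside, the threshold placement for continuous features above amounts to taking evenly strided midpoints of a sorted sample. A self-contained sketch with made-up feature values (not the Spark code path):

object ContinuousSplitSketch {
  def main(args: Array[String]): Unit = {
    val featureSamples = Array(9.0, 1.0, 4.5, 12.0, 2.0, 7.0, 9.5, 4.0).sorted
    val numBins = 4
    val numSplits = numBins - 1
    val stride: Double = featureSamples.length.toDouble / numBins   // 2.0
    val thresholds = (0 until numSplits).map { splitIndex =>
      val sampleIndex = splitIndex * stride.toInt
      // Threshold halfway between two adjacent samples.
      (featureSamples(sampleIndex) + featureSamples(sampleIndex + 1)) / 2.0
    }
    println(thresholds.mkString(", "))   // 1.5, 4.25, 8.0
  }
}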
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
new file mode 100644
index 0000000000..866d85a79b
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impl
+
+import org.apache.spark.mllib.tree.impurity._
+
+/**
+ * DecisionTree statistics aggregator.
+ * This holds a flat array of statistics for a set of (nodes, features, bins)
+ * and helps with indexing.
+ */
+private[tree] class DTStatsAggregator(
+ val metadata: DecisionTreeMetadata,
+ val numNodes: Int) extends Serializable {
+
+ /**
+ * [[ImpurityAggregator]] instance specifying the impurity type.
+ */
+ val impurityAggregator: ImpurityAggregator = metadata.impurity match {
+ case Gini => new GiniAggregator(metadata.numClasses)
+ case Entropy => new EntropyAggregator(metadata.numClasses)
+ case Variance => new VarianceAggregator()
+ case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")
+ }
+
+ /**
+ * Number of elements (Double values) used for the sufficient statistics of each bin.
+ */
+ val statsSize: Int = impurityAggregator.statsSize
+
+ val numFeatures: Int = metadata.numFeatures
+
+ /**
+ * Number of bins for each feature. This is indexed by the feature index.
+ */
+ val numBins: Array[Int] = metadata.numBins
+
+ /**
+ * Number of splits for the given feature.
+ */
+ def numSplits(featureIndex: Int): Int = metadata.numSplits(featureIndex)
+
+ /**
+ * Indicator for each feature of whether that feature is an unordered feature.
+ * TODO: Is Array[Boolean] any faster?
+ */
+ def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex)
+
+ /**
+ * Offset for each feature for calculating indices into the [[allStats]] array.
+ */
+ private val featureOffsets: Array[Int] = {
+ def featureOffsetsCalc(total: Int, featureIndex: Int): Int = {
+ if (isUnordered(featureIndex)) {
+ total + 2 * numBins(featureIndex)
+ } else {
+ total + numBins(featureIndex)
+ }
+ }
+ Range(0, numFeatures).scanLeft(0)(featureOffsetsCalc).map(statsSize * _).toArray
+ }
+
+ /**
+ * Number of elements for each node, corresponding to stride between nodes in [[allStats]].
+ */
+ private val nodeStride: Int = featureOffsets.last
+
+ /**
+ * Total number of elements stored in this aggregator.
+ */
+ val allStatsSize: Int = numNodes * nodeStride
+
+ /**
+ * Flat array of elements.
+ * Index for start of stats for a (node, feature, bin) is:
+ * index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
+ * Note: For unordered features, the left child stats have binIndex in [0, numBins(featureIndex))
+ * and the right child stats in [numBins(featureIndex), 2 * numBins(featureIndex))
+ */
+ val allStats: Array[Double] = new Array[Double](allStatsSize)
+
+ /**
+ * Get an [[ImpurityCalculator]] for a given (node, feature, bin).
+ * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset
+ * from [[getNodeFeatureOffset]].
+ * For unordered features, this is a pre-computed
+ * (node, feature, left/right child) offset from
+ * [[getLeftRightNodeFeatureOffsets]].
+ */
+ def getImpurityCalculator(nodeFeatureOffset: Int, binIndex: Int): ImpurityCalculator = {
+ impurityAggregator.getCalculator(allStats, nodeFeatureOffset + binIndex * statsSize)
+ }
+
+ /**
+ * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
+ */
+ def update(nodeIndex: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = {
+ val i = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
+ impurityAggregator.update(allStats, i, label)
+ }
+
+ /**
+ * Pre-compute node offset for use with [[nodeUpdate]].
+ */
+ def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride
+
+ /**
+ * Faster version of [[update]].
+ * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
+ * @param nodeOffset Pre-computed node offset from [[getNodeOffset]].
+ */
+ def nodeUpdate(nodeOffset: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = {
+ val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize
+ impurityAggregator.update(allStats, i, label)
+ }
+
+ /**
+ * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
+ * For ordered features only.
+ */
+ def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
+ require(!isUnordered(featureIndex),
+ s"DTStatsAggregator.getNodeFeatureOffset is for ordered features only, but was called" +
+ s" for unordered feature $featureIndex.")
+ nodeIndex * nodeStride + featureOffsets(featureIndex)
+ }
+
+ /**
+ * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
+ * For unordered features only.
+ */
+ def getLeftRightNodeFeatureOffsets(nodeIndex: Int, featureIndex: Int): (Int, Int) = {
+ require(isUnordered(featureIndex),
+ s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," +
+ s" but was called for ordered feature $featureIndex.")
+ val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex)
+ (baseOffset, baseOffset + numBins(featureIndex) * statsSize)
+ }
+
+ /**
+ * Faster version of [[update]].
+ * Update the stats for a given (node, feature, bin), using the given label.
+ * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset
+ * from [[getNodeFeatureOffset]].
+ * For unordered features, this is a pre-computed
+ * (node, feature, left/right child) offset from
+ * [[getLeftRightNodeFeatureOffsets]].
+ */
+ def nodeFeatureUpdate(nodeFeatureOffset: Int, binIndex: Int, label: Double): Unit = {
+ impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label)
+ }
+
+ /**
+ * For a given (node, feature), merge the stats for two bins.
+ * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset
+ * from [[getNodeFeatureOffset]].
+ * For unordered features, this is a pre-computed
+ * (node, feature, left/right child) offset from
+ * [[getLeftRightNodeFeatureOffsets]].
+ * @param binIndex The other bin is merged into this bin.
+ * @param otherBinIndex This bin is not modified.
+ */
+ def mergeForNodeFeature(nodeFeatureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = {
+ impurityAggregator.merge(allStats, nodeFeatureOffset + binIndex * statsSize,
+ nodeFeatureOffset + otherBinIndex * statsSize)
+ }
+
+ /**
+ * Merge another aggregator into this one, and return this aggregator.
+ * This method modifies this aggregator in-place.
+ */
+ def merge(other: DTStatsAggregator): DTStatsAggregator = {
+ require(allStatsSize == other.allStatsSize,
+ s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors."
+ + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.")
+ var i = 0
+ // TODO: Test BLAS.axpy
+ while (i < allStatsSize) {
+ allStats(i) += other.allStats(i)
+ i += 1
+ }
+ this
+ }
+
+}
+
+private[tree] object DTStatsAggregator extends Serializable {
+
+ /**
+ * Combines two aggregates (modifying the first) and returns the combination.
+ */
+ def binCombOp(
+ agg1: DTStatsAggregator,
+ agg2: DTStatsAggregator): DTStatsAggregator = {
+ agg1.merge(agg2)
+ }
+
+}
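To make the flat-array indexing in DTStatsAggregator concrete, here is a small stand-alone sketch of the offset arithmetic (hypothetical sizes; not the actual class). Feature 1 is treated as unordered, so it reserves twice its bin count for left/right child stats.

object FlatIndexSketch {
  def main(args: Array[String]): Unit = {
    val statsSize = 2                      // e.g. 2 classes with Gini or Entropy
    val numBins = Array(3, 4)              // hypothetical bins per feature
    val isUnordered = Array(false, true)   // feature 1 reserves 2 * numBins slots
    val featureOffsets: Array[Int] = numBins.indices
      .scanLeft(0) { (total, f) =>
        total + (if (isUnordered(f)) 2 * numBins(f) else numBins(f))
      }
      .map(_ * statsSize)
      .toArray                             // Array(0, 6, 22)
    val nodeStride = featureOffsets.last   // 22 doubles per node
    def index(node: Int, feature: Int, bin: Int): Int =
      node * nodeStride + featureOffsets(feature) + bin * statsSize
    println(index(1, 1, 2))                // 1 * 22 + 6 + 2 * 2 = 32
  }
}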
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index d9eda354dc..e95add7558 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -26,14 +26,15 @@ import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Impurity
import org.apache.spark.rdd.RDD
-
/**
* Learning and dataset metadata for DecisionTree.
*
* @param numClasses For classification: labels can take values {0, ..., numClasses - 1}.
* For regression: fixed at 0 (no meaning).
+ * @param maxBins Maximum number of bins, for all features.
* @param featureArity Map: categorical feature index --> arity.
* I.e., the feature takes values in {0, ..., arity - 1}.
+ * @param numBins Number of bins for each feature.
*/
private[tree] class DecisionTreeMetadata(
val numFeatures: Int,
@@ -42,6 +43,7 @@ private[tree] class DecisionTreeMetadata(
val maxBins: Int,
val featureArity: Map[Int, Int],
val unorderedFeatures: Set[Int],
+ val numBins: Array[Int],
val impurity: Impurity,
val quantileStrategy: QuantileStrategy) extends Serializable {
@@ -57,10 +59,26 @@ private[tree] class DecisionTreeMetadata(
def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex)
+ /**
+ * Number of splits for the given feature.
+ * For unordered features, there are 2 bins per split.
+ * For ordered features, there is 1 more bin than split.
+ */
+ def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) {
+ numBins(featureIndex) >> 1
+ } else {
+ numBins(featureIndex) - 1
+ }
+
}
private[tree] object DecisionTreeMetadata {
+ /**
+ * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters.
+ * This computes which categorical features will be ordered vs. unordered,
+ * as well as the number of splits and bins for each feature.
+ */
def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = {
val numFeatures = input.take(1)(0).features.size
@@ -70,32 +88,55 @@ private[tree] object DecisionTreeMetadata {
case Regression => 0
}
- val maxBins = math.min(strategy.maxBins, numExamples).toInt
- val log2MaxBinsp1 = math.log(maxBins + 1) / math.log(2.0)
+ val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt
+
+ // We check the number of bins here against maxPossibleBins.
+ // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified
+ // based on the number of training examples.
+ if (strategy.categoricalFeaturesInfo.nonEmpty) {
+ val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max
+ require(maxCategoriesPerFeature <= maxPossibleBins,
+ s"DecisionTree requires maxBins (= $maxPossibleBins) >= max categories " +
+ s"in categorical features (= $maxCategoriesPerFeature)")
+ }
val unorderedFeatures = new mutable.HashSet[Int]()
+ val numBins = Array.fill[Int](numFeatures)(maxPossibleBins)
if (numClasses > 2) {
- strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
- if (k - 1 < log2MaxBinsp1) {
- // Note: The above check is equivalent to checking:
- // numUnorderedBins = (1 << k - 1) - 1 < maxBins
- unorderedFeatures.add(f)
+ // Multiclass classification
+ val maxCategoriesForUnorderedFeature =
+ ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt
+ strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
+ // Decide if some categorical features should be treated as unordered features,
+ // which require 2 * ((1 << numCategories - 1) - 1) bins.
+ // We do this check with log values to prevent overflows in case numCategories is large.
+ // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins
+ if (numCategories <= maxCategoriesForUnorderedFeature) {
+ unorderedFeatures.add(featureIndex)
+ numBins(featureIndex) = numUnorderedBins(numCategories)
} else {
- // TODO: Allow this case, where we simply will know nothing about some categories?
- require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " +
- s"in categorical features (>= $k)")
+ numBins(featureIndex) = numCategories
}
}
} else {
- strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
- require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " +
- s"in categorical features (>= $k)")
+ // Binary classification or regression
+ strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
+ numBins(featureIndex) = numCategories
}
}
- new DecisionTreeMetadata(numFeatures, numExamples, numClasses, maxBins,
- strategy.categoricalFeaturesInfo, unorderedFeatures.toSet,
+ new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
+ strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
strategy.impurity, strategy.quantileCalculationStrategy)
}
+ /**
+ * Given the arity of a categorical feature (arity = number of categories),
+ * return the number of bins for the feature if it is to be treated as an unordered feature.
+ * There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets;
+ * there are math.pow(2, arity - 1) - 1 such splits.
+ * Each split has 2 corresponding bins.
+ */
+ def numUnorderedBins(arity: Int): Int = 2 * ((1 << arity - 1) - 1)
+
}
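A quick sanity check of the counts implied by numUnorderedBins and numSplits, pasteable into a Scala REPL (arities are hypothetical):

def numUnorderedBins(arity: Int): Int = 2 * ((1 << arity - 1) - 1)
// arity -> (numSplits, numBins) under the unordered treatment
println((3 to 5).map(k => (k, numUnorderedBins(k) / 2, numUnorderedBins(k))))
// Vector((3,3,6), (4,7,14), (5,15,30)); an ordered categorical feature with
// arity k instead gets k bins and k - 1 splits.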
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
index 170e43e222..35e361ae30 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
@@ -48,54 +48,63 @@ private[tree] object TreePoint {
* binning feature values in preparation for DecisionTree training.
* @param input Input dataset.
* @param bins Bins for features, of size (numFeatures, numBins).
- * @param metadata Learning and dataset metadata
+ * @param metadata Learning and dataset metadata
* @return TreePoint dataset representation
*/
def convertToTreeRDD(
input: RDD[LabeledPoint],
bins: Array[Array[Bin]],
metadata: DecisionTreeMetadata): RDD[TreePoint] = {
+ // Construct arrays for featureArity and isUnordered for efficiency in the inner loop.
+ val featureArity: Array[Int] = new Array[Int](metadata.numFeatures)
+ val isUnordered: Array[Boolean] = new Array[Boolean](metadata.numFeatures)
+ var featureIndex = 0
+ while (featureIndex < metadata.numFeatures) {
+ featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0)
+ isUnordered(featureIndex) = metadata.isUnordered(featureIndex)
+ featureIndex += 1
+ }
input.map { x =>
- TreePoint.labeledPointToTreePoint(x, bins, metadata)
+ TreePoint.labeledPointToTreePoint(x, bins, featureArity, isUnordered)
}
}
/**
* Convert one LabeledPoint into its TreePoint representation.
* @param bins Bins for features, of size (numFeatures, numBins).
+ * @param featureArity Array indexed by feature, with value 0 for continuous and numCategories
+ * for categorical features.
+ * @param isUnordered Array indexed by feature, with value true for unordered categorical features.
*/
private def labeledPointToTreePoint(
labeledPoint: LabeledPoint,
bins: Array[Array[Bin]],
- metadata: DecisionTreeMetadata): TreePoint = {
-
+ featureArity: Array[Int],
+ isUnordered: Array[Boolean]): TreePoint = {
val numFeatures = labeledPoint.features.size
- val numBins = bins(0).size
val arr = new Array[Int](numFeatures)
var featureIndex = 0
while (featureIndex < numFeatures) {
- arr(featureIndex) = findBin(featureIndex, labeledPoint, metadata.isContinuous(featureIndex),
- metadata.isUnordered(featureIndex), bins, metadata.featureArity)
+ arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex),
+ isUnordered(featureIndex), bins)
featureIndex += 1
}
-
new TreePoint(labeledPoint.label, arr)
}
/**
* Find bin for one (labeledPoint, feature).
*
+ * @param featureArity 0 for continuous features; number of categories for categorical features.
* @param isUnorderedFeature (only applies if feature is categorical)
* @param bins Bins for features, of size (numFeatures, numBins).
- * @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity
*/
private def findBin(
featureIndex: Int,
labeledPoint: LabeledPoint,
- isFeatureContinuous: Boolean,
+ featureArity: Int,
isUnorderedFeature: Boolean,
- bins: Array[Array[Bin]],
- categoricalFeaturesInfo: Map[Int, Int]): Int = {
+ bins: Array[Array[Bin]]): Int = {
/**
* Binary search helper method for continuous feature.
@@ -121,44 +130,7 @@ private[tree] object TreePoint {
-1
}
- /**
- * Sequential search helper method to find bin for categorical feature in multiclass
- * classification. The category is returned since each category can belong to multiple
- * splits. The actual left/right child allocation per split is performed in the
- * sequential phase of the bin aggregate operation.
- */
- def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = {
- labeledPoint.features(featureIndex).toInt
- }
-
- /**
- * Sequential search helper method to find bin for categorical feature
- * (for classification and regression).
- */
- def sequentialBinSearchForOrderedCategoricalFeature(): Int = {
- val featureCategories = categoricalFeaturesInfo(featureIndex)
- val featureValue = labeledPoint.features(featureIndex)
- var binIndex = 0
- while (binIndex < featureCategories) {
- val bin = bins(featureIndex)(binIndex)
- val categories = bin.highSplit.categories
- if (categories.contains(featureValue)) {
- return binIndex
- }
- binIndex += 1
- }
- if (featureValue < 0 || featureValue >= featureCategories) {
- throw new IllegalArgumentException(
- s"DecisionTree given invalid data:" +
- s" Feature $featureIndex is categorical with values in" +
- s" {0,...,${featureCategories - 1}," +
- s" but a data point gives it value $featureValue.\n" +
- " Bad data point: " + labeledPoint.toString)
- }
- -1
- }
-
- if (isFeatureContinuous) {
+ if (featureArity == 0) {
// Perform binary search for finding bin for continuous features.
val binIndex = binarySearchForBins()
if (binIndex == -1) {
@@ -168,18 +140,17 @@ private[tree] object TreePoint {
}
binIndex
} else {
- // Perform sequential search to find bin for categorical features.
- val binIndex = if (isUnorderedFeature) {
- sequentialBinSearchForUnorderedCategoricalFeatureInClassification()
- } else {
- sequentialBinSearchForOrderedCategoricalFeature()
- }
- if (binIndex == -1) {
- throw new RuntimeException("No bin was found for categorical feature." +
- " This error can occur when given invalid data values (such as NaN)." +
- s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}")
+ // Categorical feature bins are indexed by feature values.
+ val featureValue = labeledPoint.features(featureIndex)
+ if (featureValue < 0 || featureValue >= featureArity) {
+ throw new IllegalArgumentException(
+ s"DecisionTree given invalid data:" +
+ s" Feature $featureIndex is categorical with values in" +
+ s" {0,...,${featureArity - 1}," +
+ s" but a data point gives it value $featureValue.\n" +
+ " Bad data point: " + labeledPoint.toString)
}
- binIndex
+ featureValue.toInt
}
}
}
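The simplified categorical path in findBin boils down to "the feature value is the bin index", plus a range check. An illustrative stand-alone version (hypothetical helper, not the private Spark method):

object CategoricalBinningSketch {
  def main(args: Array[String]): Unit = {
    // A featureArity of 0 would mean continuous (binary search over thresholds);
    // for a categorical feature the value itself is the bin index.
    def findCategoricalBin(featureValue: Double, featureArity: Int): Int = {
      require(featureValue >= 0 && featureValue < featureArity,
        s"Categorical value $featureValue is outside {0, ..., ${featureArity - 1}}")
      featureValue.toInt
    }
    println(findCategoricalBin(2.0, featureArity = 4))   // 2
  }
}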
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 96d2471e1f..1c8afc2d0f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -74,3 +74,87 @@ object Entropy extends Impurity {
def instance = this
}
+
+/**
+ * Class for updating views of a vector of sufficient statistics,
+ * in order to compute impurity from a sample.
+ * Note: Instances of this class do not hold the data; they operate on views of the data.
+ * @param numClasses Number of classes for label.
+ */
+private[tree] class EntropyAggregator(numClasses: Int)
+ extends ImpurityAggregator(numClasses) with Serializable {
+
+ /**
+ * Update stats for one (node, feature, bin) with the given label.
+ * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
+ * @param offset Start index of stats for this (node, feature, bin).
+ */
+ def update(allStats: Array[Double], offset: Int, label: Double): Unit = {
+ if (label >= statsSize) {
+ throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
+ s" but requires label < numClasses (= $statsSize).")
+ }
+ allStats(offset + label.toInt) += 1
+ }
+
+ /**
+ * Get an [[ImpurityCalculator]] for a (node, feature, bin).
+ * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
+ * @param offset Start index of stats for this (node, feature, bin).
+ */
+ def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = {
+ new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray)
+ }
+
+}
+
+/**
+ * Stores statistics for one (node, feature, bin) for calculating impurity.
+ * Unlike [[EntropyAggregator]], this class stores its own data and is for a specific
+ * (node, feature, bin).
+ * @param stats Array of sufficient statistics for a (node, feature, bin).
+ */
+private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
+
+ /**
+ * Make a deep copy of this [[ImpurityCalculator]].
+ */
+ def copy: EntropyCalculator = new EntropyCalculator(stats.clone())
+
+ /**
+ * Calculate the impurity from the stored sufficient statistics.
+ */
+ def calculate(): Double = Entropy.calculate(stats, stats.sum)
+
+ /**
+ * Number of data points accounted for in the sufficient statistics.
+ */
+ def count: Long = stats.sum.toLong
+
+ /**
+ * Prediction which should be made based on the sufficient statistics.
+ */
+ def predict: Double = if (count == 0) {
+ 0
+ } else {
+ indexOfLargestArrayElement(stats)
+ }
+
+ /**
+ * Probability of the label given by [[predict]].
+ */
+ override def prob(label: Double): Double = {
+ val lbl = label.toInt
+ require(lbl < stats.length,
+ s"EntropyCalculator.prob given invalid label: $lbl (should be < ${stats.length}")
+ val cnt = count
+ if (cnt == 0) {
+ 0
+ } else {
+ stats(lbl) / cnt
+ }
+ }
+
+ override def toString: String = s"EntropyCalculator(stats = [${stats.mkString(", ")}])"
+
+}
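For classification calculators such as EntropyCalculator, count, predict, and prob reduce to simple operations on the per-class counts. A pasteable illustration with made-up counts:

val stats = Array(3.0, 7.0)                      // class counts for one (node, feature, bin)
val count = stats.sum.toLong                     // 10
val predict = stats.indexOf(stats.max).toDouble  // majority class: 1.0
val prob = stats(predict.toInt) / count          // 0.7
println((count, predict, prob))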
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index d586f44904..5cfdf345d1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -70,3 +70,87 @@ object Gini extends Impurity {
def instance = this
}
+
+/**
+ * Class for updating views of a vector of sufficient statistics,
+ * in order to compute impurity from a sample.
+ * Note: Instances of this class do not hold the data; they operate on views of the data.
+ * @param numClasses Number of classes for label.
+ */
+private[tree] class GiniAggregator(numClasses: Int)
+ extends ImpurityAggregator(numClasses) with Serializable {
+
+ /**
+ * Update stats for one (node, feature, bin) with the given label.
+ * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
+ * @param offset Start index of stats for this (node, feature, bin).
+ */
+ def update(allStats: Array[Double], offset: Int, label: Double): Unit = {
+ if (label >= statsSize) {
+ throw new IllegalArgumentException(s"GiniAggregator given label $label" +
+ s" but requires label < numClasses (= $statsSize).")
+ }
+ allStats(offset + label.toInt) += 1
+ }
+
+ /**
+ * Get an [[ImpurityCalculator]] for a (node, feature, bin).
+ * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
+ * @param offset Start index of stats for this (node, feature, bin).
+ */
+ def getCalculator(allStats: Array[Double], offset: Int): GiniCalculator = {
+ new GiniCalculator(allStats.view(offset, offset + statsSize).toArray)
+ }
+
+}
+
+/**
+ * Stores statistics for one (node, feature, bin) for calculating impurity.
+ * Unlike [[GiniAggregator]], this class stores its own data and is for a specific
+ * (node, feature, bin).
+ * @param stats Array of sufficient statistics for a (node, feature, bin).
+ */
+private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
+
+ /**
+ * Make a deep copy of this [[ImpurityCalculator]].
+ */
+ def copy: GiniCalculator = new GiniCalculator(stats.clone())
+
+ /**
+ * Calculate the impurity from the stored sufficient statistics.
+ */
+ def calculate(): Double = Gini.calculate(stats, stats.sum)
+
+ /**
+ * Number of data points accounted for in the sufficient statistics.
+ */
+ def count: Long = stats.sum.toLong
+
+ /**
+ * Prediction which should be made based on the sufficient statistics.
+ */
+ def predict: Double = if (count == 0) {
+ 0
+ } else {
+ indexOfLargestArrayElement(stats)
+ }
+
+ /**
+ * Probability of the label given by [[predict]].
+ */
+ override def prob(label: Double): Double = {
+ val lbl = label.toInt
+ require(lbl < stats.length,
+ s"GiniCalculator.prob given invalid label: $lbl (should be < ${stats.length}")
+ val cnt = count
+ if (cnt == 0) {
+ 0
+ } else {
+ stats(lbl) / cnt
+ }
+ }
+
+ override def toString: String = s"GiniCalculator(stats = [${stats.mkString(", ")}])"
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index 92b0c7b4a6..5a047d6cb5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -22,6 +22,9 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
/**
* :: Experimental ::
* Trait for calculating information gain.
+ * This trait is used for
+ * (a) setting the impurity parameter in [[org.apache.spark.mllib.tree.configuration.Strategy]]
+ * (b) calculating impurity values from sufficient statistics.
*/
@Experimental
trait Impurity extends Serializable {
@@ -47,3 +50,127 @@ trait Impurity extends Serializable {
@DeveloperApi
def calculate(count: Double, sum: Double, sumSquares: Double): Double
}
+
+/**
+ * Interface for updating views of a vector of sufficient statistics,
+ * in order to compute impurity from a sample.
+ * Note: Instances of this class do not hold the data; they operate on views of the data.
+ * @param statsSize Length of the vector of sufficient statistics for one bin.
+ */
+private[tree] abstract class ImpurityAggregator(val statsSize: Int) extends Serializable {
+
+ /**
+ * Merge the stats from one bin into another.
+ * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
+ * @param offset Start index of stats for (node, feature, bin) which is modified by the merge.
+ * @param otherOffset Start index of stats for (node, feature, other bin) which is not modified.
+ */
+ def merge(allStats: Array[Double], offset: Int, otherOffset: Int): Unit = {
+ var i = 0
+ while (i < statsSize) {
+ allStats(offset + i) += allStats(otherOffset + i)
+ i += 1
+ }
+ }
+
+ /**
+ * Update stats for one (node, feature, bin) with the given label.
+ * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
+ * @param offset Start index of stats for this (node, feature, bin).
+ */
+ def update(allStats: Array[Double], offset: Int, label: Double): Unit
+
+ /**
+ * Get an [[ImpurityCalculator]] for a (node, feature, bin).
+ * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
+ * @param offset Start index of stats for this (node, feature, bin).
+ */
+ def getCalculator(allStats: Array[Double], offset: Int): ImpurityCalculator
+
+}
+
+/**
+ * Stores statistics for one (node, feature, bin) for calculating impurity.
+ * Unlike [[ImpurityAggregator]], this class stores its own data and is for a specific
+ * (node, feature, bin).
+ * @param stats Array of sufficient statistics for a (node, feature, bin).
+ */
+private[tree] abstract class ImpurityCalculator(val stats: Array[Double]) {
+
+ /**
+ * Make a deep copy of this [[ImpurityCalculator]].
+ */
+ def copy: ImpurityCalculator
+
+ /**
+ * Calculate the impurity from the stored sufficient statistics.
+ */
+ def calculate(): Double
+
+ /**
+ * Add the stats from another calculator into this one, modifying and returning this calculator.
+ */
+ def add(other: ImpurityCalculator): ImpurityCalculator = {
+ require(stats.size == other.stats.size,
+ s"Two ImpurityCalculator instances cannot be added with different counts sizes." +
+ s" Sizes are ${stats.size} and ${other.stats.size}.")
+ var i = 0
+ while (i < other.stats.size) {
+ stats(i) += other.stats(i)
+ i += 1
+ }
+ this
+ }
+
+ /**
+ * Subtract the stats from another calculator from this one, modifying and returning this
+ * calculator.
+ */
+ def subtract(other: ImpurityCalculator): ImpurityCalculator = {
+ require(stats.size == other.stats.size,
+ s"Two ImpurityCalculator instances cannot be subtracted with different counts sizes." +
+ s" Sizes are ${stats.size} and ${other.stats.size}.")
+ var i = 0
+ while (i < other.stats.size) {
+ stats(i) -= other.stats(i)
+ i += 1
+ }
+ this
+ }
+
+ /**
+ * Number of data points accounted for in the sufficient statistics.
+ */
+ def count: Long
+
+ /**
+ * Prediction which should be made based on the sufficient statistics.
+ */
+ def predict: Double
+
+ /**
+ * Probability of the label given by [[predict]], or -1 if no probability is available.
+ */
+ def prob(label: Double): Double = -1
+
+ /**
+ * Return the index of the largest array element.
+ * Fails if the array is empty.
+ */
+ protected def indexOfLargestArrayElement(array: Array[Double]): Int = {
+ val result = array.foldLeft(-1, Double.MinValue, 0) {
+ case ((maxIndex, maxValue, currentIndex), currentValue) =>
+ if (currentValue > maxValue) {
+ (currentIndex, currentValue, currentIndex + 1)
+ } else {
+ (maxIndex, maxValue, currentIndex + 1)
+ }
+ }
+ if (result._1 < 0) {
+ throw new RuntimeException("ImpurityCalculator internal error:" +
+ " indexOfLargestArrayElement failed")
+ }
+ result._1
+ }
+
+}
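The add/subtract contract above is plain elementwise arithmetic on the stats vectors; in the split-selection code it is what lets right-child stats be derived from the node totals minus the left-child stats without another pass over the data. A minimal stand-in (not the ImpurityCalculator class itself):

object StatsArithmeticSketch {
  def main(args: Array[String]): Unit = {
    def subtractInPlace(a: Array[Double], b: Array[Double]): Array[Double] = {
      require(a.length == b.length, "stats vectors must have the same length")
      var i = 0
      while (i < a.length) { a(i) -= b(i); i += 1 }
      a
    }
    val nodeTotal = Array(10.0, 6.0)   // class counts for the whole node
    val leftChild = Array(7.0, 1.0)    // class counts that fall left of a split
    val rightChild = subtractInPlace(nodeTotal.clone(), leftChild)
    println(rightChild.mkString(","))  // 3.0,5.0
  }
}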
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index f7d99a40eb..e9ccecb1b8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -61,3 +61,75 @@ object Variance extends Impurity {
def instance = this
}
+
+/**
+ * Class for updating views of a vector of sufficient statistics,
+ * in order to compute impurity from a sample.
+ * Note: Instances of this class do not hold the data; they operate on views of the data.
+ */
+private[tree] class VarianceAggregator()
+ extends ImpurityAggregator(statsSize = 3) with Serializable {
+
+ /**
+ * Update stats for one (node, feature, bin) with the given label.
+ * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
+ * @param offset Start index of stats for this (node, feature, bin).
+ */
+ def update(allStats: Array[Double], offset: Int, label: Double): Unit = {
+ allStats(offset) += 1
+ allStats(offset + 1) += label
+ allStats(offset + 2) += label * label
+ }
+
+ /**
+ * Get an [[ImpurityCalculator]] for a (node, feature, bin).
+ * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
+ * @param offset Start index of stats for this (node, feature, bin).
+ */
+ def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator = {
+ new VarianceCalculator(allStats.view(offset, offset + statsSize).toArray)
+ }
+
+}
+
+/**
+ * Stores statistics for one (node, feature, bin) for calculating impurity.
+ * Unlike [[VarianceAggregator]], this class stores its own data and is for a specific
+ * (node, feature, bin).
+ * @param stats Array of sufficient statistics for a (node, feature, bin).
+ */
+private[tree] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
+
+ require(stats.size == 3,
+ s"VarianceCalculator requires sufficient statistics array stats to be of length 3," +
+ s" but was given array of length ${stats.size}.")
+
+ /**
+ * Make a deep copy of this [[ImpurityCalculator]].
+ */
+ def copy: VarianceCalculator = new VarianceCalculator(stats.clone())
+
+ /**
+ * Calculate the impurity from the stored sufficient statistics.
+ */
+ def calculate(): Double = Variance.calculate(stats(0), stats(1), stats(2))
+
+ /**
+ * Number of data points accounted for in the sufficient statistics.
+ */
+ def count: Long = stats(0).toLong
+
+ /**
+ * Prediction which should be made based on the sufficient statistics.
+ */
+ def predict: Double = if (count == 0) {
+ 0
+ } else {
+ stats(1) / count
+ }
+
+ override def toString: String = {
+ s"VarianceAggregator(cnt = ${stats(0)}, sum = ${stats(1)}, sum2 = ${stats(2)})"
+ }
+
+}
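For regression, the three sufficient statistics per bin are count, sum, and sum of squares; the prediction is the mean and the impurity is the variance of the labels. A quick pasteable check with made-up labels:

val labels = Array(1.0, 2.0, 2.0, 5.0)
val count = labels.length.toDouble                           // stats(0)
val sum = labels.sum                                         // stats(1)
val sumSq = labels.map(x => x * x).sum                       // stats(2)
val predict = sum / count                                    // mean = 2.5
val variance = sumSq / count - (sum / count) * (sum / count) // 2.25
println((predict, variance))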
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
index af35d88f71..0cad473782 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.model
import org.apache.spark.mllib.tree.configuration.FeatureType._
/**
- * Used for "binning" the features bins for faster best split calculation.
+ * Used for "binning" the feature values for faster best split calculation.
*
* For a continuous feature, the bin is determined by a low and a high split,
* where an example with featureValue falls into the bin s.t.
@@ -30,13 +30,16 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
* bins, splits, and feature values. The bin is determined by category/feature value.
* However, the bins are not necessarily ordered by feature value;
* they are ordered using impurity.
+ *
 * For unordered categorical features, there is a 1-1 correspondence between bins and splits,
* where bins and splits correspond to subsets of feature values (in highSplit.categories).
+ * An unordered feature with k categories uses (1 << k - 1) - 1 bins, corresponding to all
+ * partitionings of categories into 2 disjoint, non-empty sets.
*
* @param lowSplit signifying the lower threshold for the continuous feature to be
* accepted in the bin
* @param highSplit signifying the upper threshold for the continuous feature to be
- * accepted in the bin
+ * accepted in the bin
* @param featureType type of feature -- categorical or continuous
* @param category categorical label value accepted in the bin for ordered features
*/
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 0eee626278..5b8a4cbed2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -24,8 +24,13 @@ import org.apache.spark.mllib.linalg.Vector
/**
* :: DeveloperApi ::
- * Node in a decision tree
- * @param id integer node id
+ * Node in a decision tree.
+ *
+ * About node indexing:
+ * Nodes are indexed from 1. Node 1 is the root; nodes 2, 3 are the left, right children.
+ * Node index 0 is not used.
+ *
+ * @param id integer node id, from 1
* @param predict predicted value at the node
 * @param isLeaf whether the node is a leaf
* @param split split to calculate left and right nodes
@@ -51,17 +56,13 @@ class Node (
* @param nodes array of nodes
*/
def build(nodes: Array[Node]): Unit = {
-
- logDebug("building node " + id + " at level " +
- (scala.math.log(id + 1)/scala.math.log(2)).toInt )
+ logDebug("building node " + id + " at level " + Node.indexToLevel(id))
logDebug("id = " + id + ", split = " + split)
logDebug("stats = " + stats)
logDebug("predict = " + predict)
if (!isLeaf) {
- val leftNodeIndex = id * 2 + 1
- val rightNodeIndex = id * 2 + 2
- leftNode = Some(nodes(leftNodeIndex))
- rightNode = Some(nodes(rightNodeIndex))
+ leftNode = Some(nodes(Node.leftChildIndex(id)))
+ rightNode = Some(nodes(Node.rightChildIndex(id)))
leftNode.get.build(nodes)
rightNode.get.build(nodes)
}
@@ -96,24 +97,20 @@ class Node (
* Get the number of nodes in tree below this node, including leaf nodes.
* E.g., if this is a leaf, returns 0. If both children are leaves, returns 2.
*/
- private[tree] def numDescendants: Int = {
- if (isLeaf) {
- 0
- } else {
- 2 + leftNode.get.numDescendants + rightNode.get.numDescendants
- }
+ private[tree] def numDescendants: Int = if (isLeaf) {
+ 0
+ } else {
+ 2 + leftNode.get.numDescendants + rightNode.get.numDescendants
}
/**
* Get depth of tree from this node.
* E.g.: Depth 0 means this is a leaf node.
*/
- private[tree] def subtreeDepth: Int = {
- if (isLeaf) {
- 0
- } else {
- 1 + math.max(leftNode.get.subtreeDepth, rightNode.get.subtreeDepth)
- }
+ private[tree] def subtreeDepth: Int = if (isLeaf) {
+ 0
+ } else {
+ 1 + math.max(leftNode.get.subtreeDepth, rightNode.get.subtreeDepth)
}
/**
@@ -148,3 +145,49 @@ class Node (
}
}
+
+private[tree] object Node {
+
+ /**
+ * Return the index of the left child of this node.
+ */
+ def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1
+
+ /**
+ * Return the index of the right child of this node.
+ */
+ def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1
+
+ /**
+ * Get the parent index of the given node, or 0 if it is the root.
+ */
+ def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1
+
+ /**
+ * Return the level of the tree that the given node is in.
+ */
+ def indexToLevel(nodeIndex: Int): Int = if (nodeIndex == 0) {
+ throw new IllegalArgumentException(s"0 is not a valid node index.")
+ } else {
+ java.lang.Integer.numberOfTrailingZeros(java.lang.Integer.highestOneBit(nodeIndex))
+ }
+
+ /**
+ * Returns true if this is a left child.
+ * Note: Returns false for the root.
+ */
+ def isLeftChild(nodeIndex: Int): Boolean = nodeIndex > 1 && nodeIndex % 2 == 0
+
+ /**
+ * Return the maximum number of nodes which can be in the given level of the tree.
+ * @param level Level of tree (0 = root).
+ */
+ def maxNodesInLevel(level: Int): Int = 1 << level
+
+ /**
+ * Return the index of the first node in the given level.
+ * @param level Level of tree (0 = root).
+ */
+ def startIndexInLevel(level: Int): Int = 1 << level
+
+}
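With 1-based node indexing, the parent/child/level relations in object Node are pure bit arithmetic. A pasteable sanity check mirroring the helpers above:

def leftChildIndex(i: Int): Int = i << 1
def rightChildIndex(i: Int): Int = (i << 1) + 1
def parentIndex(i: Int): Int = i >> 1
def indexToLevel(i: Int): Int =
  java.lang.Integer.numberOfTrailingZeros(java.lang.Integer.highestOneBit(i))
def maxNodesInLevel(level: Int): Int = 1 << level
def startIndexInLevel(level: Int): Int = 1 << level
println((leftChildIndex(1), rightChildIndex(1)))     // (2,3): children of the root
println((parentIndex(5), indexToLevel(5)))           // (2,2): node 5 sits at level 2 under node 2
println((maxNodesInLevel(2), startIndexInLevel(2)))  // (4,4): level 2 holds at most 4 nodes, starting at index 4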
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 2f36fd9077..8e556c917b 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -21,15 +21,16 @@ import scala.collection.JavaConverters._
import org.scalatest.FunSuite
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint}
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node}
-import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.LocalSparkContext
-import org.apache.spark.mllib.regression.LabeledPoint
+
class DecisionTreeSuite extends FunSuite with LocalSparkContext {
@@ -59,12 +60,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
}
- test("split and bin calculation") {
+ test("Binary classification with continuous features: split and bin calculation") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Gini, 3, 2, 100)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(!metadata.isUnordered(featureIndex = 0))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(bins.length === 2)
@@ -72,7 +74,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins(0).length === 100)
}
- test("split and bin calculation for categorical variables") {
+ test("Binary classification with binary (ordered) categorical features:" +
+ " split and bin calculation") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
@@ -83,77 +86,20 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ assert(!metadata.isUnordered(featureIndex = 1))
assert(splits.length === 2)
assert(bins.length === 2)
- assert(splits(0).length === 99)
- assert(bins(0).length === 100)
-
- // Check splits.
-
- assert(splits(0)(0).feature === 0)
- assert(splits(0)(0).threshold === Double.MinValue)
- assert(splits(0)(0).featureType === Categorical)
- assert(splits(0)(0).categories.length === 1)
- assert(splits(0)(0).categories.contains(1.0))
-
- assert(splits(0)(1).feature === 0)
- assert(splits(0)(1).threshold === Double.MinValue)
- assert(splits(0)(1).featureType === Categorical)
- assert(splits(0)(1).categories.length === 2)
- assert(splits(0)(1).categories.contains(1.0))
- assert(splits(0)(1).categories.contains(0.0))
-
- assert(splits(0)(2) === null)
-
- assert(splits(1)(0).feature === 1)
- assert(splits(1)(0).threshold === Double.MinValue)
- assert(splits(1)(0).featureType === Categorical)
- assert(splits(1)(0).categories.length === 1)
- assert(splits(1)(0).categories.contains(0.0))
-
- assert(splits(1)(1).feature === 1)
- assert(splits(1)(1).threshold === Double.MinValue)
- assert(splits(1)(1).featureType === Categorical)
- assert(splits(1)(1).categories.length === 2)
- assert(splits(1)(1).categories.contains(1.0))
- assert(splits(1)(1).categories.contains(0.0))
-
- assert(splits(1)(2) === null)
-
- // Check bins.
-
- assert(bins(0)(0).category === 1.0)
- assert(bins(0)(0).lowSplit.categories.length === 0)
- assert(bins(0)(0).highSplit.categories.length === 1)
- assert(bins(0)(0).highSplit.categories.contains(1.0))
-
- assert(bins(0)(1).category === 0.0)
- assert(bins(0)(1).lowSplit.categories.length === 1)
- assert(bins(0)(1).lowSplit.categories.contains(1.0))
- assert(bins(0)(1).highSplit.categories.length === 2)
- assert(bins(0)(1).highSplit.categories.contains(1.0))
- assert(bins(0)(1).highSplit.categories.contains(0.0))
-
- assert(bins(0)(2) === null)
-
- assert(bins(1)(0).category === 0.0)
- assert(bins(1)(0).lowSplit.categories.length === 0)
- assert(bins(1)(0).highSplit.categories.length === 1)
- assert(bins(1)(0).highSplit.categories.contains(0.0))
-
- assert(bins(1)(1).category === 1.0)
- assert(bins(1)(1).lowSplit.categories.length === 1)
- assert(bins(1)(1).lowSplit.categories.contains(0.0))
- assert(bins(1)(1).highSplit.categories.length === 2)
- assert(bins(1)(1).highSplit.categories.contains(0.0))
- assert(bins(1)(1).highSplit.categories.contains(1.0))
-
- assert(bins(1)(2) === null)
+ // no bins or splits pre-computed for ordered categorical features
+ assert(splits(0).length === 0)
+ assert(bins(0).length === 0)
}
- test("split and bin calculations for categorical variables with no sample for one category") {
+ test("Binary classification with 3-ary (ordered) categorical features," +
+ " with no samples for one category") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
@@ -164,104 +110,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ assert(!metadata.isUnordered(featureIndex = 1))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-
- // Check splits.
-
- assert(splits(0)(0).feature === 0)
- assert(splits(0)(0).threshold === Double.MinValue)
- assert(splits(0)(0).featureType === Categorical)
- assert(splits(0)(0).categories.length === 1)
- assert(splits(0)(0).categories.contains(1.0))
-
- assert(splits(0)(1).feature === 0)
- assert(splits(0)(1).threshold === Double.MinValue)
- assert(splits(0)(1).featureType === Categorical)
- assert(splits(0)(1).categories.length === 2)
- assert(splits(0)(1).categories.contains(1.0))
- assert(splits(0)(1).categories.contains(0.0))
-
- assert(splits(0)(2).feature === 0)
- assert(splits(0)(2).threshold === Double.MinValue)
- assert(splits(0)(2).featureType === Categorical)
- assert(splits(0)(2).categories.length === 3)
- assert(splits(0)(2).categories.contains(1.0))
- assert(splits(0)(2).categories.contains(0.0))
- assert(splits(0)(2).categories.contains(2.0))
-
- assert(splits(0)(3) === null)
-
- assert(splits(1)(0).feature === 1)
- assert(splits(1)(0).threshold === Double.MinValue)
- assert(splits(1)(0).featureType === Categorical)
- assert(splits(1)(0).categories.length === 1)
- assert(splits(1)(0).categories.contains(0.0))
-
- assert(splits(1)(1).feature === 1)
- assert(splits(1)(1).threshold === Double.MinValue)
- assert(splits(1)(1).featureType === Categorical)
- assert(splits(1)(1).categories.length === 2)
- assert(splits(1)(1).categories.contains(1.0))
- assert(splits(1)(1).categories.contains(0.0))
-
- assert(splits(1)(2).feature === 1)
- assert(splits(1)(2).threshold === Double.MinValue)
- assert(splits(1)(2).featureType === Categorical)
- assert(splits(1)(2).categories.length === 3)
- assert(splits(1)(2).categories.contains(1.0))
- assert(splits(1)(2).categories.contains(0.0))
- assert(splits(1)(2).categories.contains(2.0))
-
- assert(splits(1)(3) === null)
-
- // Check bins.
-
- assert(bins(0)(0).category === 1.0)
- assert(bins(0)(0).lowSplit.categories.length === 0)
- assert(bins(0)(0).highSplit.categories.length === 1)
- assert(bins(0)(0).highSplit.categories.contains(1.0))
-
- assert(bins(0)(1).category === 0.0)
- assert(bins(0)(1).lowSplit.categories.length === 1)
- assert(bins(0)(1).lowSplit.categories.contains(1.0))
- assert(bins(0)(1).highSplit.categories.length === 2)
- assert(bins(0)(1).highSplit.categories.contains(1.0))
- assert(bins(0)(1).highSplit.categories.contains(0.0))
-
- assert(bins(0)(2).category === 2.0)
- assert(bins(0)(2).lowSplit.categories.length === 2)
- assert(bins(0)(2).lowSplit.categories.contains(1.0))
- assert(bins(0)(2).lowSplit.categories.contains(0.0))
- assert(bins(0)(2).highSplit.categories.length === 3)
- assert(bins(0)(2).highSplit.categories.contains(1.0))
- assert(bins(0)(2).highSplit.categories.contains(0.0))
- assert(bins(0)(2).highSplit.categories.contains(2.0))
-
- assert(bins(0)(3) === null)
-
- assert(bins(1)(0).category === 0.0)
- assert(bins(1)(0).lowSplit.categories.length === 0)
- assert(bins(1)(0).highSplit.categories.length === 1)
- assert(bins(1)(0).highSplit.categories.contains(0.0))
-
- assert(bins(1)(1).category === 1.0)
- assert(bins(1)(1).lowSplit.categories.length === 1)
- assert(bins(1)(1).lowSplit.categories.contains(0.0))
- assert(bins(1)(1).highSplit.categories.length === 2)
- assert(bins(1)(1).highSplit.categories.contains(0.0))
- assert(bins(1)(1).highSplit.categories.contains(1.0))
-
- assert(bins(1)(2).category === 2.0)
- assert(bins(1)(2).lowSplit.categories.length === 2)
- assert(bins(1)(2).lowSplit.categories.contains(0.0))
- assert(bins(1)(2).lowSplit.categories.contains(1.0))
- assert(bins(1)(2).highSplit.categories.length === 3)
- assert(bins(1)(2).highSplit.categories.contains(0.0))
- assert(bins(1)(2).highSplit.categories.contains(1.0))
- assert(bins(1)(2).highSplit.categories.contains(2.0))
-
- assert(bins(1)(3) === null)
+ assert(splits.length === 2)
+ assert(bins.length === 2)
+ // no bins or splits pre-computed for ordered categorical features
+ assert(splits(0).length === 0)
+ assert(bins(0).length === 0)
}
test("extract categories from a number for multiclass classification") {
@@ -270,8 +128,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(List(3.0, 2.0, 0.0).toSeq === l.toSeq)
}
- test("split and bin calculations for unordered categorical variables with multiclass " +
- "classification") {
+ test("Multiclass classification with unordered categorical features:" +
+ " split and bin calculations") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
@@ -282,8 +140,15 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 100,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(metadata.isUnordered(featureIndex = 0))
+ assert(metadata.isUnordered(featureIndex = 1))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+ assert(splits.length === 2)
+ assert(bins.length === 2)
+ assert(splits(0).length === 3)
+ assert(bins(0).length === 6)
// Expecting 2^2 - 1 = 3 bins/splits
assert(splits(0)(0).feature === 0)
@@ -321,10 +186,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(splits(1)(2).categories.contains(0.0))
assert(splits(1)(2).categories.contains(1.0))
- assert(splits(0)(3) === null)
- assert(splits(1)(3) === null)
-
-
// Check bins.
assert(bins(0)(0).category === Double.MinValue)
@@ -360,13 +221,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins(1)(2).highSplit.categories.contains(1.0))
assert(bins(1)(2).highSplit.categories.contains(0.0))
- assert(bins(0)(3) === null)
- assert(bins(1)(3) === null)
-
}
- test("split and bin calculations for ordered categorical variables with multiclass " +
- "classification") {
+ test("Multiclass classification with ordered categorical features: split and bin calculations") {
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
assert(arr.length === 3000)
val rdd = sc.parallelize(arr)
@@ -377,52 +234,21 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 100,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 10, 1-> 10))
+ // 2^10 - 1 > 100, so categorical features will be ordered
+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ assert(!metadata.isUnordered(featureIndex = 1))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-
- // 2^10 - 1 > 100, so categorical variables will be ordered
-
- assert(splits(0)(0).feature === 0)
- assert(splits(0)(0).threshold === Double.MinValue)
- assert(splits(0)(0).featureType === Categorical)
- assert(splits(0)(0).categories.length === 1)
- assert(splits(0)(0).categories.contains(1.0))
-
- assert(splits(0)(1).feature === 0)
- assert(splits(0)(1).threshold === Double.MinValue)
- assert(splits(0)(1).featureType === Categorical)
- assert(splits(0)(1).categories.length === 2)
- assert(splits(0)(1).categories.contains(2.0))
-
- assert(splits(0)(2).feature === 0)
- assert(splits(0)(2).threshold === Double.MinValue)
- assert(splits(0)(2).featureType === Categorical)
- assert(splits(0)(2).categories.length === 3)
- assert(splits(0)(2).categories.contains(2.0))
- assert(splits(0)(2).categories.contains(1.0))
-
- assert(splits(0)(10) === null)
- assert(splits(1)(10) === null)
-
-
- // Check bins.
-
- assert(bins(0)(0).category === 1.0)
- assert(bins(0)(0).lowSplit.categories.length === 0)
- assert(bins(0)(0).highSplit.categories.length === 1)
- assert(bins(0)(0).highSplit.categories.contains(1.0))
- assert(bins(0)(1).category === 2.0)
- assert(bins(0)(1).lowSplit.categories.length === 1)
- assert(bins(0)(1).highSplit.categories.length === 2)
- assert(bins(0)(1).highSplit.categories.contains(1.0))
- assert(bins(0)(1).highSplit.categories.contains(2.0))
-
- assert(bins(0)(10) === null)
-
+ assert(splits.length === 2)
+ assert(bins.length === 2)
+ // no bins or splits pre-computed for ordered categorical features
+ assert(splits(0).length === 0)
+ assert(bins(0).length === 0)
}
- test("classification stump with all categorical variables") {
+ test("Binary classification stump with ordered categorical features") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
@@ -433,15 +259,23 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
maxDepth = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ assert(!metadata.isUnordered(featureIndex = 1))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+ assert(splits.length === 2)
+ assert(bins.length === 2)
+ // no bins or splits pre-computed for ordered categorical features
+ assert(splits(0).length === 0)
+ assert(bins(0).length === 0)
+
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0,
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
new Array[Node](0), splits, bins, 10)
val split = bestSplits(0)._1
- assert(split.categories.length === 1)
- assert(split.categories.contains(1.0))
+ assert(split.categories === List(1.0))
assert(split.featureType === Categorical)
assert(split.threshold === Double.MinValue)
@@ -452,7 +286,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(stats.impurity > 0.2)
}
- test("regression stump with all categorical variables") {
+ test("Regression stump with 3-ary (ordered) categorical features") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
@@ -462,10 +296,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
maxDepth = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ assert(!metadata.isUnordered(featureIndex = 1))
+
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0,
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
new Array[Node](0), splits, bins, 10)
val split = bestSplits(0)._1
@@ -480,7 +318,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(stats.impurity > 0.2)
}
- test("regression stump with categorical variables of arity 2") {
+ test("Regression stump with binary (ordered) categorical features") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
@@ -490,6 +328,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
maxDepth = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ assert(!metadata.isUnordered(featureIndex = 1))
val model = DecisionTree.train(rdd, strategy)
validateRegressor(model, arr, 0.0)
@@ -497,12 +338,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(model.depth === 1)
}
- test("stump with fixed label 0 for Gini") {
+ test("Binary classification stump with fixed label 0 for Gini") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
- val strategy = new Strategy(Classification, Gini, 3, 2, 100)
+ val strategy = new Strategy(Classification, Gini, maxDepth = 3,
+ numClassesForClassification = 2, maxBins = 100)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ assert(!metadata.isUnordered(featureIndex = 1))
+
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(splits(0).length === 99)
@@ -512,7 +357,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins(0).length === 100)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0,
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
@@ -521,12 +366,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(0)._2.rightImpurity === 0)
}
- test("stump with fixed label 1 for Gini") {
+ test("Binary classification stump with fixed label 1 for Gini") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
- val strategy = new Strategy(Classification, Gini, 3, 2, 100)
+ val strategy = new Strategy(Classification, Gini, maxDepth = 3,
+ numClassesForClassification = 2, maxBins = 100)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ assert(!metadata.isUnordered(featureIndex = 1))
+
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(splits(0).length === 99)
@@ -536,7 +385,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins(0).length === 100)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0,
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
@@ -546,12 +395,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(0)._2.predict === 1)
}
- test("stump with fixed label 0 for Entropy") {
+ test("Binary classification stump with fixed label 0 for Entropy") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
- val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
+ val strategy = new Strategy(Classification, Entropy, maxDepth = 3,
+ numClassesForClassification = 2, maxBins = 100)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ assert(!metadata.isUnordered(featureIndex = 1))
+
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(splits(0).length === 99)
@@ -561,7 +414,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins(0).length === 100)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0,
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
@@ -571,12 +424,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(0)._2.predict === 0)
}
- test("stump with fixed label 1 for Entropy") {
+ test("Binary classification stump with fixed label 1 for Entropy") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
- val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
+ val strategy = new Strategy(Classification, Entropy, maxDepth = 3,
+ numClassesForClassification = 2, maxBins = 100)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ assert(!metadata.isUnordered(featureIndex = 1))
+
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(splits(0).length === 99)
@@ -586,7 +443,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins(0).length === 100)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0,
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
@@ -596,7 +453,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(0)._2.predict === 1)
}
- test("second level node building with/without groups") {
+ test("Second level node building with vs. without groups") {
val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
@@ -613,12 +470,12 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
// Train a 1-node model
val strategyOneNode = new Strategy(Classification, Entropy, 1, 2, 100)
val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
- val nodes: Array[Node] = new Array[Node](7)
- nodes(0) = modelOneNode.topNode
- nodes(0).leftNode = None
- nodes(0).rightNode = None
+ val nodes: Array[Node] = new Array[Node](8)
+ nodes(1) = modelOneNode.topNode
+ nodes(1).leftNode = None
+ nodes(1).rightNode = None
- val parentImpurities = Array(0.5, 0.5, 0.5)
+ val parentImpurities = Array(0, 0.5, 0.5, 0.5)
// Single group second level tree construction.
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
@@ -648,16 +505,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
}
}
- test("stump with categorical variables for multiclass classification") {
+ test("Multiclass classification stump with 3-ary (unordered) categorical features") {
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
assert(strategy.isMulticlassClassification)
+ assert(metadata.isUnordered(featureIndex = 0))
+ assert(metadata.isUnordered(featureIndex = 1))
+
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
@@ -668,7 +528,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplit.featureType === Categorical)
}
- test("stump with 1 continuous variable for binary classification, to check off-by-1 error") {
+ test("Binary classification stump with 1 continuous feature, to check off-by-1 error") {
val arr = new Array[LabeledPoint](4)
arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0))
arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0))
@@ -684,26 +544,27 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(model.depth === 1)
}
- test("stump with 2 continuous variables for binary classification") {
+ test("Binary classification stump with 2 continuous features") {
val arr = new Array[LabeledPoint](4)
arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
arr(3) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0))))
- val input = sc.parallelize(arr)
+ val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClassesForClassification = 2)
- val model = DecisionTree.train(input, strategy)
+ val model = DecisionTree.train(rdd, strategy)
validateClassifier(model, arr, 1.0)
assert(model.numNodes === 3)
assert(model.depth === 1)
assert(model.topNode.split.get.feature === 1)
}
- test("stump with categorical variables for multiclass classification, with just enough bins") {
- val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features
+ test("Multiclass classification stump with unordered categorical features," +
+ " with just enough bins") {
+ val maxBins = 2 * (math.pow(2, 3 - 1).toInt - 1) // just enough bins to allow unordered features
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
@@ -711,6 +572,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
assert(strategy.isMulticlassClassification)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(metadata.isUnordered(featureIndex = 0))
+ assert(metadata.isUnordered(featureIndex = 1))
val model = DecisionTree.train(rdd, strategy)
validateClassifier(model, arr, 1.0)
@@ -719,7 +582,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
@@ -733,7 +596,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(gain.rightImpurity === 0)
}
- test("stump with continuous variables for multiclass classification") {
+ test("Multiclass classification stump with continuous features") {
val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
@@ -746,7 +609,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
@@ -759,20 +622,21 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
}
- test("stump with continuous + categorical variables for multiclass classification") {
+ test("Multiclass classification stump with continuous + unordered categorical features") {
val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3))
assert(strategy.isMulticlassClassification)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(metadata.isUnordered(featureIndex = 0))
val model = DecisionTree.train(rdd, strategy)
validateClassifier(model, arr, 0.9)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
@@ -784,17 +648,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplit.threshold < 2020)
}
- test("stump with categorical variables for ordered multiclass classification") {
+ test("Multiclass classification stump with 10-ary (ordered) categorical features") {
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
assert(strategy.isMulticlassClassification)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ assert(!metadata.isUnordered(featureIndex = 1))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
@@ -805,6 +671,18 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplit.featureType === Categorical)
}
+ test("Multiclass classification tree with 10-ary (ordered) categorical features," +
+ " with just enough bins") {
+ val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
+ numClassesForClassification = 3, maxBins = 10,
+ categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
+ assert(strategy.isMulticlassClassification)
+
+ val model = DecisionTree.train(rdd, strategy)
+ validateClassifier(model, arr, 0.6)
+ }
}
@@ -899,5 +777,4 @@ object DecisionTreeSuite {
arr
}
-
}