-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala               | 878
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala  | 101
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala             |  30
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala                  |  18
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala    |   2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala               |  28
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala                 |  16
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala                |   5
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala          | 167
9 files changed, 615 insertions, 630 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 2a3107a13e..6b9a8f72c2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -27,7 +27,7 @@ import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
-import org.apache.spark.mllib.tree.impl.{TimeTracker, TreePoint}
+import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TimeTracker, TreePoint}
import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity}
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
@@ -62,43 +62,38 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
timer.start("init")
val retaggedInput = input.retag(classOf[LabeledPoint])
+ val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, strategy)
logDebug("algo = " + strategy.algo)
// Find the splits and the corresponding bins (interval between the splits) using a sample
// of the input data.
timer.start("findSplitsBins")
- val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata)
val numBins = bins(0).length
timer.stop("findSplitsBins")
logDebug("numBins = " + numBins)
+ // Bin feature values (TreePoint representation).
// Cache input RDD for speedup during multiple passes.
- val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins)
+ val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
.persist(StorageLevel.MEMORY_AND_DISK)
+ val numFeatures = metadata.numFeatures
// depth of the decision tree
val maxDepth = strategy.maxDepth
// the max number of nodes possible given the depth of the tree
- val maxNumNodes = math.pow(2, maxDepth + 1).toInt - 1
- // Initialize an array to hold filters applied to points for each node.
- val filters = new Array[List[Filter]](maxNumNodes)
- // The filter at the top node is an empty list.
- filters(0) = List()
+ val maxNumNodes = (2 << maxDepth) - 1
// Initialize an array to hold parent impurity calculations for each node.
val parentImpurities = new Array[Double](maxNumNodes)
// dummy value for top node (updated during first split calculation)
val nodes = new Array[Node](maxNumNodes)
- // num features
- val numFeatures = treeInput.take(1)(0).binnedFeatures.size
// Calculate level for single group construction
// Max memory usage for aggregates
val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024
logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
- val numElementsPerNode = DecisionTree.getElementsPerNode(numFeatures, numBins,
- strategy.numClassesForClassification, strategy.isMulticlassWithCategoricalFeatures,
- strategy.algo)
+ val numElementsPerNode = DecisionTree.getElementsPerNode(metadata, numBins)
logDebug("numElementsPerNode = " + numElementsPerNode)
val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array
@@ -114,9 +109,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
/*
* The main idea here is to perform level-wise training of the decision tree nodes thus
* reducing the passes over the data from l to log2(l) where l is the total number of nodes.
- * Each data sample is checked for validity w.r.t to each node at a given level -- i.e.,
- * the sample is only used for the split calculation at the node if the sampled would have
- * still survived the filters of the parent nodes.
+ * Each data sample is handled by a particular node at that level (or it reaches a leaf
+ * beforehand and is not used in later levels).
*/
var level = 0
@@ -130,22 +124,37 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
// Find best split for all nodes at a level.
timer.start("findBestSplits")
val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities,
- strategy, level, filters, splits, bins, maxLevelForSingleGroup, timer)
+ metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
timer.stop("findBestSplits")
+ val levelNodeIndexOffset = (1 << level) - 1
for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
+ val nodeIndex = levelNodeIndexOffset + index
+ val isLeftChild = level != 0 && nodeIndex % 2 == 1
+ val parentNodeIndex = if (isLeftChild) { // -1 for root node
+ (nodeIndex - 1) / 2
+ } else {
+ (nodeIndex - 2) / 2
+ }
+ // Extract info for this node (index) at the current level.
timer.start("extractNodeInfo")
- // Extract info for nodes at the current level.
extractNodeInfo(nodeSplitStats, level, index, nodes)
timer.stop("extractNodeInfo")
- timer.start("extractInfoForLowerLevels")
+ if (level != 0) {
+ // Set parent.
+ if (isLeftChild) {
+ nodes(parentNodeIndex).leftNode = Some(nodes(nodeIndex))
+ } else {
+ nodes(parentNodeIndex).rightNode = Some(nodes(nodeIndex))
+ }
+ }
// Extract info for nodes at the next lower level.
- extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
- filters)
+ timer.start("extractInfoForLowerLevels")
+ extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities)
timer.stop("extractInfoForLowerLevels")
logDebug("final best split = " + nodeSplitStats._1)
}
- require(math.pow(2, level) == splitsStatsForLevel.length)
+ require((1 << level) == splitsStatsForLevel.length)
// Check whether all the nodes at the current level are leaves.
val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
logDebug("all leaf = " + allLeaf)
@@ -183,7 +192,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
nodes: Array[Node]): Unit = {
val split = nodeSplitStats._1
val stats = nodeSplitStats._2
- val nodeIndex = math.pow(2, level).toInt - 1 + index
+ val nodeIndex = (1 << level) - 1 + index
val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
logDebug("Node = " + node)
@@ -198,31 +207,21 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
index: Int,
maxDepth: Int,
nodeSplitStats: (Split, InformationGainStats),
- parentImpurities: Array[Double],
- filters: Array[List[Filter]]): Unit = {
- // 0 corresponds to the left child node and 1 corresponds to the right child node.
- var i = 0
- while (i <= 1) {
- // Calculate the index of the node from the node level and the index at the current level.
- val nodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index + i
- if (level < maxDepth) {
- val impurity = if (i == 0) {
- nodeSplitStats._2.leftImpurity
- } else {
- nodeSplitStats._2.rightImpurity
- }
- logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity)
- // noting the parent impurities
- parentImpurities(nodeIndex) = impurity
- // noting the parents filters for the child nodes
- val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1)
- filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2)
- for (filter <- filters(nodeIndex)) {
- logDebug("Filter = " + filter)
- }
- }
- i += 1
+ parentImpurities: Array[Double]): Unit = {
+
+ if (level >= maxDepth) {
+ return
}
+
+ val leftNodeIndex = (2 << level) - 1 + 2 * index
+ val leftImpurity = nodeSplitStats._2.leftImpurity
+ logDebug("leftNodeIndex = " + leftNodeIndex + ", impurity = " + leftImpurity)
+ parentImpurities(leftNodeIndex) = leftImpurity
+
+ val rightNodeIndex = leftNodeIndex + 1
+ val rightImpurity = nodeSplitStats._2.rightImpurity
+ logDebug("rightNodeIndex = " + rightNodeIndex + ", impurity = " + rightImpurity)
+ parentImpurities(rightNodeIndex) = rightImpurity
}
}
@@ -434,10 +433,8 @@ object DecisionTree extends Serializable with Logging {
*
* @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
* @param parentImpurities Impurities for all parent nodes for the current level
- * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
- * parameters for constructing the DecisionTree
+ * @param metadata Learning and dataset metadata
* @param level Level of the tree
- * @param filters Filters for all nodes at a given level
* @param splits possible splits for all features
* @param bins possible bins for all features
* @param maxLevelForSingleGroup the deepest level for single-group level-wise computation.
@@ -446,9 +443,9 @@ object DecisionTree extends Serializable with Logging {
protected[tree] def findBestSplits(
input: RDD[TreePoint],
parentImpurities: Array[Double],
- strategy: Strategy,
+ metadata: DecisionTreeMetadata,
level: Int,
- filters: Array[List[Filter]],
+ nodes: Array[Node],
splits: Array[Array[Split]],
bins: Array[Array[Bin]],
maxLevelForSingleGroup: Int,
@@ -459,34 +456,32 @@ object DecisionTree extends Serializable with Logging {
// the nodes are divided into multiple groups at each level with the number of groups
// increasing exponentially per level. For example, if maxLevelForSingleGroup is 10,
// numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
- val numGroups = math.pow(2, level - maxLevelForSingleGroup).toInt
+ val numGroups = 1 << level - maxLevelForSingleGroup
logDebug("numGroups = " + numGroups)
var bestSplits = new Array[(Split, InformationGainStats)](0)
// Iterate over each group of nodes at a level.
var groupIndex = 0
while (groupIndex < numGroups) {
- val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level,
- filters, splits, bins, timer, numGroups, groupIndex)
+ val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, metadata, level,
+ nodes, splits, bins, timer, numGroups, groupIndex)
bestSplits = Array.concat(bestSplits, bestSplitsForGroup)
groupIndex += 1
}
bestSplits
} else {
- findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins, timer)
+ findBestSplitsPerGroup(input, parentImpurities, metadata, level, nodes, splits, bins, timer)
}
}
- /**
+ /**
* Returns an array of optimal splits for a group of nodes at a given level
*
* @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
* @param parentImpurities Impurities for all parent nodes for the current level
- * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
- * parameters for constructing the DecisionTree
+ * @param metadata Learning and dataset metadata
* @param level Level of the tree
- * @param filters Filters for all nodes at a given level
* @param splits possible splits for all features
- * @param bins possible bins for all features
+ * @param bins possible bins for all features, indexed as (numFeatures)(numBins)
* @param numGroups total number of node groups at the current level. Default value is set to 1.
* @param groupIndex index of the node group being processed. Default value is set to 0.
* @return array of splits with best splits for all nodes at a given level.
@@ -494,9 +489,9 @@ object DecisionTree extends Serializable with Logging {
private def findBestSplitsPerGroup(
input: RDD[TreePoint],
parentImpurities: Array[Double],
- strategy: Strategy,
+ metadata: DecisionTreeMetadata,
level: Int,
- filters: Array[List[Filter]],
+ nodes: Array[Node],
splits: Array[Array[Split]],
bins: Array[Array[Bin]],
timer: TimeTracker,
@@ -515,7 +510,7 @@ object DecisionTree extends Serializable with Logging {
* We use a bin-wise best split computation strategy instead of a straightforward best split
* computation strategy. Instead of analyzing each sample for contribution to the left/right
* child node impurity of every split, we first categorize each feature of a sample into a
- * bin. Each bin is an interval between a low and high split. Since each splits, and thus bin,
+ * bin. Each bin is an interval between a low and high split. Since each split, and thus bin,
* is ordered (read ordering for categorical variables in the findSplitsBins method),
* we exploit this structure to calculate aggregates for bins and then use these aggregates
* to calculate information gain for each split.
@@ -531,160 +526,124 @@ object DecisionTree extends Serializable with Logging {
// numNodes: Number of nodes in this (level of tree, group),
// where nodes at deeper (larger) levels may be divided into groups.
- val numNodes = math.pow(2, level).toInt / numGroups
+ val numNodes = (1 << level) / numGroups
logDebug("numNodes = " + numNodes)
// Find the number of features by looking at the first sample.
- val numFeatures = input.first().binnedFeatures.size
+ val numFeatures = metadata.numFeatures
logDebug("numFeatures = " + numFeatures)
// numBins: Number of bins = 1 + number of possible splits
val numBins = bins(0).length
logDebug("numBins = " + numBins)
- val numClasses = strategy.numClassesForClassification
+ val numClasses = metadata.numClasses
logDebug("numClasses = " + numClasses)
- val isMulticlassClassification = strategy.isMulticlassClassification
- logDebug("isMulticlassClassification = " + isMulticlassClassification)
+ val isMulticlass = metadata.isMulticlass
+ logDebug("isMulticlass = " + isMulticlass)
- val isMulticlassClassificationWithCategoricalFeatures
- = strategy.isMulticlassWithCategoricalFeatures
- logDebug("isMultiClassWithCategoricalFeatures = " +
- isMulticlassClassificationWithCategoricalFeatures)
+ val isMulticlassWithCategoricalFeatures = metadata.isMulticlassWithCategoricalFeatures
+ logDebug("isMultiClassWithCategoricalFeatures = " + isMulticlassWithCategoricalFeatures)
// shift when more than one group is used at deep tree level
val groupShift = numNodes * groupIndex
- /** Find the filters used before reaching the current code. */
- def findParentFilters(nodeIndex: Int): List[Filter] = {
- if (level == 0) {
- List[Filter]()
- } else {
- val nodeFilterIndex = math.pow(2, level).toInt - 1 + nodeIndex + groupShift
- filters(nodeFilterIndex)
- }
- }
-
/**
- * Find whether the sample is valid input for the current node, i.e., whether it passes through
- * all the filters for the current node.
+ * Get the node index corresponding to this data point.
+ * This function mimics prediction, passing an example from the root node down to a node
+ * at the current level being trained; that node's index is returned.
+ *
+ * @return Leaf index if the data point reaches a leaf.
+ * Otherwise, last node reachable in tree matching this example.
*/
- def isSampleValid(parentFilters: List[Filter], treePoint: TreePoint): Boolean = {
- // leaf
- if ((level > 0) && (parentFilters.length == 0)) {
- return false
- }
-
- // Apply each filter and check sample validity. Return false when invalid condition found.
- parentFilters.foreach { filter =>
- val featureIndex = filter.split.feature
- val comparison = filter.comparison
- val isFeatureContinuous = filter.split.featureType == Continuous
- if (isFeatureContinuous) {
- val binId = treePoint.binnedFeatures(featureIndex)
- val bin = bins(featureIndex)(binId)
- val featureValue = bin.highSplit.threshold
- val threshold = filter.split.threshold
- comparison match {
- case -1 => if (featureValue > threshold) return false
- case 1 => if (featureValue <= threshold) return false
+ def predictNodeIndex(node: Node, binnedFeatures: Array[Int]): Int = {
+ if (node.isLeaf) {
+ node.id
+ } else {
+ val featureIndex = node.split.get.feature
+ val splitLeft = node.split.get.featureType match {
+ case Continuous => {
+ val binIndex = binnedFeatures(featureIndex)
+ val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
+ // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold]
+ // We do not need to check lowSplit since bins are separated by splits.
+ featureValueUpperBound <= node.split.get.threshold
}
- } else {
- val numFeatureCategories = strategy.categoricalFeaturesInfo(featureIndex)
- val isSpaceSufficientForAllCategoricalSplits =
- numBins > math.pow(2, numFeatureCategories.toInt - 1) - 1
- val isUnorderedFeature =
- isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
- val featureValue = if (isUnorderedFeature) {
- treePoint.binnedFeatures(featureIndex)
+ case Categorical => {
+ val featureValue = if (metadata.isUnordered(featureIndex)) {
+ binnedFeatures(featureIndex)
+ } else {
+ val binIndex = binnedFeatures(featureIndex)
+ bins(featureIndex)(binIndex).category
+ }
+ node.split.get.categories.contains(featureValue)
+ }
+ case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.")
+ }
+ if (node.leftNode.isEmpty || node.rightNode.isEmpty) {
+ // Return index from next layer of nodes to train
+ if (splitLeft) {
+ node.id * 2 + 1 // left
} else {
- val binId = treePoint.binnedFeatures(featureIndex)
- bins(featureIndex)(binId).category
+ node.id * 2 + 2 // right
}
- val containsFeature = filter.split.categories.contains(featureValue)
- comparison match {
- case -1 => if (!containsFeature) return false
- case 1 => if (containsFeature) return false
+ } else {
+ if (splitLeft) {
+ predictNodeIndex(node.leftNode.get, binnedFeatures)
+ } else {
+ predictNodeIndex(node.rightNode.get, binnedFeatures)
}
}
}
+ }
- // Return true when the sample is valid for all filters.
- true
+ def nodeIndexToLevel(idx: Int): Int = {
+ if (idx == 0) {
+ 0
+ } else {
+ math.floor(math.log(idx) / math.log(2)).toInt
+ }
}
+ // Used for treePointToNodeIndex
+ val levelOffset = (1 << level) - 1
+
/**
- * Finds bins for all nodes (and all features) at a given level.
- * For l nodes, k features the storage is as follows:
- * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk,
- * where b_ij is an integer between 0 and numBins - 1 for regressions and binary
- * classification and the categorical feature value in multiclass classification.
- * Invalid sample is denoted by noting bin for feature 1 as -1.
- *
- * For unordered features, the "bin index" returned is actually the feature value (category).
- *
- * @return Array of size 1 + numFeatures * numNodes, where
- * arr(0) = label for labeledPoint, and
- * arr(1 + numFeatures * nodeIndex + featureIndex) =
- * bin index for this labeledPoint
- * (or InvalidBinIndex if labeledPoint is not handled by this node)
+ * Find the node index for the given example.
+ * Nodes are indexed from 0 at the start of this (level, group).
+ * If the example does not reach this level, returns a value < 0.
*/
- def findBinsForLevel(treePoint: TreePoint): Array[Double] = {
- // Calculate bin index and label per feature per node.
- val arr = new Array[Double](1 + (numFeatures * numNodes))
- // First element of the array is the label of the instance.
- arr(0) = treePoint.label
- // Iterate over nodes.
- var nodeIndex = 0
- while (nodeIndex < numNodes) {
- val parentFilters = findParentFilters(nodeIndex)
- // Find out whether the sample qualifies for the particular node.
- val sampleValid = isSampleValid(parentFilters, treePoint)
- val shift = 1 + numFeatures * nodeIndex
- if (!sampleValid) {
- // Mark one bin as -1 is sufficient.
- arr(shift) = InvalidBinIndex
- } else {
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- arr(shift + featureIndex) = treePoint.binnedFeatures(featureIndex)
- featureIndex += 1
- }
- }
- nodeIndex += 1
+ def treePointToNodeIndex(treePoint: TreePoint): Int = {
+ if (level == 0) {
+ 0
+ } else {
+ val globalNodeIndex = predictNodeIndex(nodes(0), treePoint.binnedFeatures)
+ // Get index for this (level, group).
+ globalNodeIndex - levelOffset - groupShift
}
- arr
}
- // Find feature bins for all nodes at a level.
- timer.start("aggregation")
- val binMappedRDD = input.map(x => findBinsForLevel(x))
-
/**
* Increment aggregate in location for (node, feature, bin, label).
*
- * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label.
- * Array of size 1 + (numFeatures * numNodes).
+ * @param treePoint Data point being aggregated.
* @param agg Array storing aggregate calculation, of size:
* numClasses * numBins * numFeatures * numNodes.
* Indexed by (node, feature, bin, label) where label is the least significant bit.
+ * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
*/
def updateBinForOrderedFeature(
- arr: Array[Double],
+ treePoint: TreePoint,
agg: Array[Double],
nodeIndex: Int,
- label: Double,
featureIndex: Int): Unit = {
- // Find the bin index for this feature.
- val arrShift = 1 + numFeatures * nodeIndex
- val arrIndex = arrShift + featureIndex
// Update the left or right count for one bin.
val aggIndex =
numClasses * numBins * numFeatures * nodeIndex +
numClasses * numBins * featureIndex +
- numClasses * arr(arrIndex).toInt +
- label.toInt
+ numClasses * treePoint.binnedFeatures(featureIndex) +
+ treePoint.label.toInt
agg(aggIndex) += 1
}
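
The aggregate array touched by updateBinForOrderedFeature above flattens a four-dimensional count tensor, indexed by (node, feature, bin, label), into a single Array[Double] with label as the least significant dimension. As a reading aid only (this sketch is not part of the patch, and the object and parameter names are invented for illustration), the index arithmetic can be reproduced like this:

object BinAggregateIndexSketch {
  // Flattened classification-aggregate layout assumed from the hunk above:
  // indexed by (node, feature, bin, label), label being the innermost dimension.
  def aggIndex(
      nodeIndex: Int,
      featureIndex: Int,
      binIndex: Int,
      label: Int,
      numClasses: Int,
      numBins: Int,
      numFeatures: Int): Int = {
    numClasses * numBins * numFeatures * nodeIndex +
      numClasses * numBins * featureIndex +
      numClasses * binIndex +
      label
  }

  def main(args: Array[String]): Unit = {
    // With 2 classes, 4 bins, and 3 features, (node 1, feature 2, bin 3, label 1)
    // lands at 2*4*3*1 + 2*4*2 + 2*3 + 1 = 47.
    assert(aggIndex(1, 2, 3, 1, numClasses = 2, numBins = 4, numFeatures = 3) == 47)
  }
}

Keeping label innermost is what lets the aggregation loop bump a single counter per feature with one multiply-accumulate chain.
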
@@ -693,8 +652,8 @@ object DecisionTree extends Serializable with Logging {
* where [bins] ranges over all bins.
* Updates left or right side of aggregate depending on split.
*
- * @param arr arr(0) = label.
- * arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category)
+ * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
+ * @param treePoint Data point being aggregated.
* @param agg Indexed by (left/right, node, feature, bin, label)
* where label is the least significant bit.
* The left/right specifier is a 0/1 index indicating left/right child info.
@@ -703,21 +662,18 @@ object DecisionTree extends Serializable with Logging {
def updateBinForUnorderedFeature(
nodeIndex: Int,
featureIndex: Int,
- arr: Array[Double],
- label: Double,
+ treePoint: TreePoint,
agg: Array[Double],
rightChildShift: Int): Unit = {
- // Find the bin index for this feature.
- val arrIndex = 1 + numFeatures * nodeIndex + featureIndex
- val featureValue = arr(arrIndex).toInt
+ val featureValue = treePoint.binnedFeatures(featureIndex)
// Update the left or right count for one bin.
val aggShift =
numClasses * numBins * numFeatures * nodeIndex +
numClasses * numBins * featureIndex +
- label.toInt
+ treePoint.label.toInt
// Find all matching bins and increment their values
- val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
- val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
+ val featureCategories = metadata.featureArity(featureIndex)
+ val numCategoricalBins = (1 << featureCategories - 1) - 1
var binIndex = 0
while (binIndex < numCategoricalBins) {
val aggIndex = aggShift + binIndex * numClasses
@@ -733,30 +689,21 @@ object DecisionTree extends Serializable with Logging {
/**
* Helper for binSeqOp.
*
- * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label.
- * Array of size 1 + (numFeatures * numNodes).
* @param agg Array storing aggregate calculation, of size:
* numClasses * numBins * numFeatures * numNodes.
* Indexed by (node, feature, bin, label) where label is the least significant bit.
+ * @param treePoint Data point being aggregated.
+ * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
*/
- def binaryOrNotCategoricalBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = {
- // Iterate over all nodes.
- var nodeIndex = 0
- while (nodeIndex < numNodes) {
- // Check whether the instance was valid for this nodeIndex.
- val validSignalIndex = 1 + numFeatures * nodeIndex
- val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
- if (isSampleValidForNode) {
- // actual class label
- val label = arr(0)
- // Iterate over all features.
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
- featureIndex += 1
- }
- }
- nodeIndex += 1
+ def binaryOrNotCategoricalBinSeqOp(
+ agg: Array[Double],
+ treePoint: TreePoint,
+ nodeIndex: Int): Unit = {
+ // Iterate over all features.
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex)
+ featureIndex += 1
}
}
@@ -765,49 +712,28 @@ object DecisionTree extends Serializable with Logging {
/**
* Helper for binSeqOp.
*
- * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label.
- * Array of size 1 + (numFeatures * numNodes).
- * For ordered features,
- * arr(1 + featureIndex + nodeIndex * numFeatures) = bin index.
- * For unordered features,
- * arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category).
* @param agg Array storing aggregate calculation.
* For ordered features, this is of size:
* numClasses * numBins * numFeatures * numNodes.
* For unordered features, this is of size:
* 2 * numClasses * numBins * numFeatures * numNodes.
+ * @param treePoint Data point being aggregated.
+ * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
*/
- def multiclassWithCategoricalBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = {
- // Iterate over all nodes.
- var nodeIndex = 0
- while (nodeIndex < numNodes) {
- // Check whether the instance was valid for this nodeIndex.
- val validSignalIndex = 1 + numFeatures * nodeIndex
- val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
- if (isSampleValidForNode) {
- // actual class label
- val label = arr(0)
- // Iterate over all features.
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
- if (isFeatureContinuous) {
- updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
- } else {
- val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
- val isSpaceSufficientForAllCategoricalSplits
- = numBins > math.pow(2, featureCategories.toInt - 1) - 1
- if (isSpaceSufficientForAllCategoricalSplits) {
- updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg,
- rightChildShift)
- } else {
- updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
- }
- }
- featureIndex += 1
- }
+ def multiclassWithCategoricalBinSeqOp(
+ agg: Array[Double],
+ treePoint: TreePoint,
+ nodeIndex: Int): Unit = {
+ val label = treePoint.label
+ // Iterate over all features.
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ if (metadata.isUnordered(featureIndex)) {
+ updateBinForUnorderedFeature(nodeIndex, featureIndex, treePoint, agg, rightChildShift)
+ } else {
+ updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex)
}
- nodeIndex += 1
+ featureIndex += 1
}
}
@@ -818,36 +744,25 @@ object DecisionTree extends Serializable with Logging {
*
* @param agg Array storing aggregate calculation, updated by this function.
* Size: 3 * numBins * numFeatures * numNodes
- * @param arr Bin mapping from findBinsForLevel.
- * Array of size 1 + (numFeatures * numNodes).
+ * @param treePoint Data point being aggregated.
+ * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
* @return agg
*/
- def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = {
- // Iterate over all nodes.
- var nodeIndex = 0
- while (nodeIndex < numNodes) {
- // Check whether the instance was valid for this nodeIndex.
- val validSignalIndex = 1 + numFeatures * nodeIndex
- val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
- if (isSampleValidForNode) {
- // actual class label
- val label = arr(0)
- // Iterate over all features.
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- // Find the bin index for this feature.
- val arrShift = 1 + numFeatures * nodeIndex
- val arrIndex = arrShift + featureIndex
- // Update count, sum, and sum^2 for one bin.
- val aggShift = 3 * numBins * numFeatures * nodeIndex
- val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3
- agg(aggIndex) = agg(aggIndex) + 1
- agg(aggIndex + 1) = agg(aggIndex + 1) + label
- agg(aggIndex + 2) = agg(aggIndex + 2) + label * label
- featureIndex += 1
- }
- }
- nodeIndex += 1
+ def regressionBinSeqOp(agg: Array[Double], treePoint: TreePoint, nodeIndex: Int): Unit = {
+ val label = treePoint.label
+ // Iterate over all features.
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ // Update count, sum, and sum^2 for one bin.
+ val binIndex = treePoint.binnedFeatures(featureIndex)
+ val aggIndex =
+ 3 * numBins * numFeatures * nodeIndex +
+ 3 * numBins * featureIndex +
+ 3 * binIndex
+ agg(aggIndex) += 1
+ agg(aggIndex + 1) += label
+ agg(aggIndex + 2) += label * label
+ featureIndex += 1
}
}
@@ -866,26 +781,30 @@ object DecisionTree extends Serializable with Logging {
* 2 * numClasses * numBins * numFeatures * numNodes for unordered features.
* Size for regression:
* 3 * numBins * numFeatures * numNodes.
- * @param arr Bin mapping from findBinsForLevel.
- * Array of size 1 + (numFeatures * numNodes).
+ * @param treePoint Data point being aggregated.
* @return agg
*/
- def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = {
- strategy.algo match {
- case Classification =>
- if(isMulticlassClassificationWithCategoricalFeatures) {
- multiclassWithCategoricalBinSeqOp(arr, agg)
+ def binSeqOp(agg: Array[Double], treePoint: TreePoint): Array[Double] = {
+ val nodeIndex = treePointToNodeIndex(treePoint)
+ // If the example does not reach this level, then nodeIndex < 0.
+ // If the example reaches this level but is handled in a different group,
+ // then either nodeIndex < 0 (previous group) or nodeIndex >= numNodes (later group).
+ if (nodeIndex >= 0 && nodeIndex < numNodes) {
+ if (metadata.isClassification) {
+ if (isMulticlassWithCategoricalFeatures) {
+ multiclassWithCategoricalBinSeqOp(agg, treePoint, nodeIndex)
} else {
- binaryOrNotCategoricalBinSeqOp(arr, agg)
+ binaryOrNotCategoricalBinSeqOp(agg, treePoint, nodeIndex)
}
- case Regression => regressionBinSeqOp(arr, agg)
+ } else {
+ regressionBinSeqOp(agg, treePoint, nodeIndex)
+ }
}
agg
}
// Calculate bin aggregate length for classification or regression.
- val binAggregateLength = numNodes * getElementsPerNode(numFeatures, numBins, numClasses,
- isMulticlassClassificationWithCategoricalFeatures, strategy.algo)
+ val binAggregateLength = numNodes * getElementsPerNode(metadata, numBins)
logDebug("binAggregateLength = " + binAggregateLength)
/**
@@ -905,144 +824,134 @@ object DecisionTree extends Serializable with Logging {
}
// Calculate bin aggregates.
+ timer.start("aggregation")
val binAggregates = {
- binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp)
+ input.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp)
}
timer.stop("aggregation")
logDebug("binAggregates.length = " + binAggregates.length)
/**
- * Calculates the information gain for all splits based upon left/right split aggregates.
- * @param leftNodeAgg left node aggregates
- * @param featureIndex feature index
- * @param splitIndex split index
- * @param rightNodeAgg right node aggregate
+ * Calculate the information gain for a given (feature, split) based upon left/right aggregates.
+ * @param leftNodeAgg left node aggregates for this (feature, split)
+ * @param rightNodeAgg right node aggregate for this (feature, split)
* @param topImpurity impurity of the parent node
* @return information gain and statistics for all splits
*/
def calculateGainForSplit(
- leftNodeAgg: Array[Array[Array[Double]]],
- featureIndex: Int,
- splitIndex: Int,
- rightNodeAgg: Array[Array[Array[Double]]],
+ leftNodeAgg: Array[Double],
+ rightNodeAgg: Array[Double],
topImpurity: Double): InformationGainStats = {
- strategy.algo match {
- case Classification =>
- val leftCounts: Array[Double] = leftNodeAgg(featureIndex)(splitIndex)
- val rightCounts: Array[Double] = rightNodeAgg(featureIndex)(splitIndex)
- val leftTotalCount = leftCounts.sum
- val rightTotalCount = rightCounts.sum
-
- val impurity = {
- if (level > 0) {
- topImpurity
- } else {
- // Calculate impurity for root node.
- val rootNodeCounts = new Array[Double](numClasses)
- var classIndex = 0
- while (classIndex < numClasses) {
- rootNodeCounts(classIndex) = leftCounts(classIndex) + rightCounts(classIndex)
- classIndex += 1
- }
- strategy.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount)
- }
- }
+ if (metadata.isClassification) {
+ val leftTotalCount = leftNodeAgg.sum
+ val rightTotalCount = rightNodeAgg.sum
- val totalCount = leftTotalCount + rightTotalCount
- if (totalCount == 0) {
- // Return arbitrary prediction.
- return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
+ val impurity = {
+ if (level > 0) {
+ topImpurity
+ } else {
+ // Calculate impurity for root node.
+ val rootNodeCounts = new Array[Double](numClasses)
+ var classIndex = 0
+ while (classIndex < numClasses) {
+ rootNodeCounts(classIndex) = leftNodeAgg(classIndex) + rightNodeAgg(classIndex)
+ classIndex += 1
+ }
+ metadata.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount)
}
+ }
- // Sum of count for each label
- val leftRightCounts: Array[Double] =
- leftCounts.zip(rightCounts).map { case (leftCount, rightCount) =>
- leftCount + rightCount
- }
+ val totalCount = leftTotalCount + rightTotalCount
+ if (totalCount == 0) {
+ // Return arbitrary prediction.
+ return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
+ }
- def indexOfLargestArrayElement(array: Array[Double]): Int = {
- val result = array.foldLeft(-1, Double.MinValue, 0) {
- case ((maxIndex, maxValue, currentIndex), currentValue) =>
- if (currentValue > maxValue) {
- (currentIndex, currentValue, currentIndex + 1)
- } else {
- (maxIndex, maxValue, currentIndex + 1)
- }
- }
- if (result._1 < 0) {
- throw new RuntimeException("DecisionTree internal error:" +
- " calculateGainForSplit failed in indexOfLargestArrayElement")
- }
- result._1
+ // Sum of count for each label
+ val leftrightNodeAgg: Array[Double] =
+ leftNodeAgg.zip(rightNodeAgg).map { case (leftCount, rightCount) =>
+ leftCount + rightCount
}
- val predict = indexOfLargestArrayElement(leftRightCounts)
- val prob = leftRightCounts(predict) / totalCount
-
- val leftImpurity = if (leftTotalCount == 0) {
- topImpurity
- } else {
- strategy.impurity.calculate(leftCounts, leftTotalCount)
+ def indexOfLargestArrayElement(array: Array[Double]): Int = {
+ val result = array.foldLeft(-1, Double.MinValue, 0) {
+ case ((maxIndex, maxValue, currentIndex), currentValue) =>
+ if (currentValue > maxValue) {
+ (currentIndex, currentValue, currentIndex + 1)
+ } else {
+ (maxIndex, maxValue, currentIndex + 1)
+ }
}
- val rightImpurity = if (rightTotalCount == 0) {
- topImpurity
- } else {
- strategy.impurity.calculate(rightCounts, rightTotalCount)
+ if (result._1 < 0) {
+ throw new RuntimeException("DecisionTree internal error:" +
+ " calculateGainForSplit failed in indexOfLargestArrayElement")
}
+ result._1
+ }
- val leftWeight = leftTotalCount / totalCount
- val rightWeight = rightTotalCount / totalCount
+ val predict = indexOfLargestArrayElement(leftrightNodeAgg)
+ val prob = leftrightNodeAgg(predict) / totalCount
- val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+ val leftImpurity = if (leftTotalCount == 0) {
+ topImpurity
+ } else {
+ metadata.impurity.calculate(leftNodeAgg, leftTotalCount)
+ }
+ val rightImpurity = if (rightTotalCount == 0) {
+ topImpurity
+ } else {
+ metadata.impurity.calculate(rightNodeAgg, rightTotalCount)
+ }
- new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
+ val leftWeight = leftTotalCount / totalCount
+ val rightWeight = rightTotalCount / totalCount
- case Regression =>
- val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0)
- val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1)
- val leftSumSquares = leftNodeAgg(featureIndex)(splitIndex)(2)
+ val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
- val rightCount = rightNodeAgg(featureIndex)(splitIndex)(0)
- val rightSum = rightNodeAgg(featureIndex)(splitIndex)(1)
- val rightSumSquares = rightNodeAgg(featureIndex)(splitIndex)(2)
+ new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
- val impurity = {
- if (level > 0) {
- topImpurity
- } else {
- // Calculate impurity for root node.
- val count = leftCount + rightCount
- val sum = leftSum + rightSum
- val sumSquares = leftSumSquares + rightSumSquares
- strategy.impurity.calculate(count, sum, sumSquares)
- }
- }
+ } else {
+ // Regression
- if (leftCount == 0) {
- return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,
- rightSum / rightCount)
- }
- if (rightCount == 0) {
- return new InformationGainStats(0, topImpurity ,topImpurity,
- Double.MinValue, leftSum / leftCount)
+ val leftCount = leftNodeAgg(0)
+ val leftSum = leftNodeAgg(1)
+ val leftSumSquares = leftNodeAgg(2)
+
+ val rightCount = rightNodeAgg(0)
+ val rightSum = rightNodeAgg(1)
+ val rightSumSquares = rightNodeAgg(2)
+
+ val impurity = {
+ if (level > 0) {
+ topImpurity
+ } else {
+ // Calculate impurity for root node.
+ val count = leftCount + rightCount
+ val sum = leftSum + rightSum
+ val sumSquares = leftSumSquares + rightSumSquares
+ metadata.impurity.calculate(count, sum, sumSquares)
}
+ }
+
+ if (leftCount == 0) {
+ return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,
+ rightSum / rightCount)
+ }
+ if (rightCount == 0) {
+ return new InformationGainStats(0, topImpurity, topImpurity,
+ Double.MinValue, leftSum / leftCount)
+ }
- val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares)
- val rightImpurity = strategy.impurity.calculate(rightCount, rightSum, rightSumSquares)
+ val leftImpurity = metadata.impurity.calculate(leftCount, leftSum, leftSumSquares)
+ val rightImpurity = metadata.impurity.calculate(rightCount, rightSum, rightSumSquares)
- val leftWeight = leftCount.toDouble / (leftCount + rightCount)
- val rightWeight = rightCount.toDouble / (leftCount + rightCount)
+ val leftWeight = leftCount.toDouble / (leftCount + rightCount)
+ val rightWeight = rightCount.toDouble / (leftCount + rightCount)
- val gain = {
- if (level > 0) {
- impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
- } else {
- impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
- }
- }
+ val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
- val predict = (leftSum + rightSum) / (leftCount + rightCount)
- new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict)
+ val predict = (leftSum + rightSum) / (leftCount + rightCount)
+ new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict)
}
}
@@ -1065,6 +974,19 @@ object DecisionTree extends Serializable with Logging {
binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = {
+ /**
+ * The input binData is indexed as (feature, bin, class).
+ * This computes cumulative sums over splits.
+ * Each (feature, class) pair is handled separately.
+ * Note: numSplits = numBins - 1.
+ * @param leftNodeAgg Each (feature, class) slice is an array over splits.
+ * Element i (i = 0, ..., numSplits - 2) is set to be
+ * the cumulative sum (from left) over binData for bins 0, ..., i.
+ * @param rightNodeAgg Each (feature, class) slice is an array over splits.
+ * Element i (i = 1, ..., numSplits - 1) is set to be
+ * the cumulative sum (from right) over binData for bins
+ * numBins - 1, ..., numBins - 1 - i.
+ */
def findAggForOrderedFeatureClassification(
leftNodeAgg: Array[Array[Array[Double]]],
rightNodeAgg: Array[Array[Array[Double]]],
@@ -1169,45 +1091,32 @@ object DecisionTree extends Serializable with Logging {
}
}
- strategy.algo match {
- case Classification =>
- // Initialize left and right split aggregates.
- val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
- val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- if (isMulticlassClassificationWithCategoricalFeatures) {
- val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
- if (isFeatureContinuous) {
- findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
- } else {
- val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
- val isSpaceSufficientForAllCategoricalSplits
- = numBins > math.pow(2, featureCategories.toInt - 1) - 1
- if (isSpaceSufficientForAllCategoricalSplits) {
- findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
- } else {
- findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
- }
- }
- } else {
- findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
- }
- featureIndex += 1
- }
-
- (leftNodeAgg, rightNodeAgg)
- case Regression =>
- // Initialize left and right split aggregates.
- val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3)
- val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3)
- // Iterate over all features.
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex)
- featureIndex += 1
+ if (metadata.isClassification) {
+ // Initialize left and right split aggregates.
+ val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
+ val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ if (metadata.isUnordered(featureIndex)) {
+ findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
+ } else {
+ findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
}
- (leftNodeAgg, rightNodeAgg)
+ featureIndex += 1
+ }
+ (leftNodeAgg, rightNodeAgg)
+ } else {
+ // Regression
+ // Initialize left and right split aggregates.
+ val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3)
+ val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3)
+ // Iterate over all features.
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex)
+ featureIndex += 1
+ }
+ (leftNodeAgg, rightNodeAgg)
}
}
@@ -1225,8 +1134,9 @@ object DecisionTree extends Serializable with Logging {
val numSplitsForFeature = getNumSplitsForFeature(featureIndex)
var splitIndex = 0
while (splitIndex < numSplitsForFeature) {
- gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex,
- splitIndex, rightNodeAgg, nodeImpurity)
+ gains(featureIndex)(splitIndex) =
+ calculateGainForSplit(leftNodeAgg(featureIndex)(splitIndex),
+ rightNodeAgg(featureIndex)(splitIndex), nodeImpurity)
splitIndex += 1
}
featureIndex += 1
@@ -1238,18 +1148,14 @@ object DecisionTree extends Serializable with Logging {
* Get the number of splits for a feature.
*/
def getNumSplitsForFeature(featureIndex: Int): Int = {
- val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
- if (isFeatureContinuous) {
+ if (metadata.isContinuous(featureIndex)) {
numBins - 1
} else {
// Categorical feature
- val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
- val isSpaceSufficientForAllCategoricalSplits =
- numBins > math.pow(2, featureCategories.toInt - 1) - 1
- if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
- math.pow(2.0, featureCategories - 1).toInt - 1
+ val featureCategories = metadata.featureArity(featureIndex)
+ if (metadata.isUnordered(featureIndex)) {
+ (1 << featureCategories - 1) - 1
} else {
- // Ordered features
featureCategories
}
}
@@ -1308,29 +1214,29 @@ object DecisionTree extends Serializable with Logging {
* Get bin data for one node.
*/
def getBinDataForNode(node: Int): Array[Double] = {
- strategy.algo match {
- case Classification =>
- if (isMulticlassClassificationWithCategoricalFeatures) {
- val shift = numClasses * node * numBins * numFeatures
- val rightChildShift = numClasses * numBins * numFeatures * numNodes
- val binsForNode = {
- val leftChildData
- = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures)
- val rightChildData
- = binAggregates.slice(rightChildShift + shift,
- rightChildShift + shift + numClasses * numBins * numFeatures)
- leftChildData ++ rightChildData
- }
- binsForNode
- } else {
- val shift = numClasses * node * numBins * numFeatures
- val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures)
- binsForNode
+ if (metadata.isClassification) {
+ if (isMulticlassWithCategoricalFeatures) {
+ val shift = numClasses * node * numBins * numFeatures
+ val rightChildShift = numClasses * numBins * numFeatures * numNodes
+ val binsForNode = {
+ val leftChildData
+ = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures)
+ val rightChildData
+ = binAggregates.slice(rightChildShift + shift,
+ rightChildShift + shift + numClasses * numBins * numFeatures)
+ leftChildData ++ rightChildData
}
- case Regression =>
- val shift = 3 * node * numBins * numFeatures
- val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures)
binsForNode
+ } else {
+ val shift = numClasses * node * numBins * numFeatures
+ val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures)
+ binsForNode
+ }
+ } else {
+ // Regression
+ val shift = 3 * node * numBins * numFeatures
+ val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures)
+ binsForNode
}
}
@@ -1340,7 +1246,7 @@ object DecisionTree extends Serializable with Logging {
// Iterating over all nodes at this level
var node = 0
while (node < numNodes) {
- val nodeImpurityIndex = math.pow(2, level).toInt - 1 + node + groupShift
+ val nodeImpurityIndex = (1 << level) - 1 + node + groupShift
val binsForNode: Array[Double] = getBinDataForNode(node)
logDebug("nodeImpurityIndex = " + nodeImpurityIndex)
val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
@@ -1358,20 +1264,15 @@ object DecisionTree extends Serializable with Logging {
*
* @param numBins Number of bins = 1 + number of possible splits.
*/
- private def getElementsPerNode(
- numFeatures: Int,
- numBins: Int,
- numClasses: Int,
- isMulticlassClassificationWithCategoricalFeatures: Boolean,
- algo: Algo): Int = {
- algo match {
- case Classification =>
- if (isMulticlassClassificationWithCategoricalFeatures) {
- 2 * numClasses * numBins * numFeatures
- } else {
- numClasses * numBins * numFeatures
- }
- case Regression => 3 * numBins * numFeatures
+ private def getElementsPerNode(metadata: DecisionTreeMetadata, numBins: Int): Int = {
+ if (metadata.isClassification) {
+ if (metadata.isMulticlassWithCategoricalFeatures) {
+ 2 * metadata.numClasses * numBins * metadata.numFeatures
+ } else {
+ metadata.numClasses * numBins * metadata.numFeatures
+ }
+ } else {
+ 3 * numBins * metadata.numFeatures
}
}
@@ -1390,16 +1291,15 @@ object DecisionTree extends Serializable with Logging {
* For multiclass classification with a low-arity feature
* (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
* the feature is split based on subsets of categories.
- * There are math.pow(2, maxFeatureValue - 1) - 1 splits.
+ * There are (1 << maxFeatureValue - 1) - 1 splits.
* (b) "ordered features"
* For regression and binary classification,
* and for multiclass classification with a high-arity feature,
* there is one bin per category.
*
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
- * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
- * parameters for construction the DecisionTree
- * @return A tuple of (splits,bins).
+ * @param metadata Learning and dataset metadata
+ * @return A tuple of (splits, bins).
* Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]]
* of size (numFeatures, numBins - 1).
* Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]]
@@ -1407,19 +1307,18 @@ object DecisionTree extends Serializable with Logging {
*/
protected[tree] def findSplitsBins(
input: RDD[LabeledPoint],
- strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = {
+ metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = {
val count = input.count()
// Find the number of features by looking at the first sample
val numFeatures = input.take(1)(0).features.size
- val maxBins = strategy.maxBins
+ val maxBins = metadata.maxBins
val numBins = if (maxBins <= count) maxBins else count.toInt
logDebug("numBins = " + numBins)
- val isMulticlassClassification = strategy.isMulticlassClassification
- logDebug("isMulticlassClassification = " + isMulticlassClassification)
-
+ val isMulticlass = metadata.isMulticlass
+ logDebug("isMulticlass = " + isMulticlass)
/*
* Ensure numBins is always greater than the categories. For multiclass classification,
@@ -1431,13 +1330,12 @@ object DecisionTree extends Serializable with Logging {
* by the number of training examples.
* TODO: Allow this case, where we simply will know nothing about some categories.
*/
- if (strategy.categoricalFeaturesInfo.size > 0) {
- val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
+ if (metadata.featureArity.size > 0) {
+ val maxCategoriesForFeatures = metadata.featureArity.maxBy(_._2)._2
require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " +
"in categorical features")
}
-
// Calculate the number of samples for approximate quantile calculation.
val requiredSamples = numBins*numBins
val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
@@ -1451,7 +1349,7 @@ object DecisionTree extends Serializable with Logging {
val stride: Double = numSamples.toDouble / numBins
logDebug("stride = " + stride)
- strategy.quantileCalculationStrategy match {
+ metadata.quantileStrategy match {
case Sort =>
val splits = Array.ofDim[Split](numFeatures, numBins - 1)
val bins = Array.ofDim[Bin](numFeatures, numBins)
@@ -1462,7 +1360,7 @@ object DecisionTree extends Serializable with Logging {
var featureIndex = 0
while (featureIndex < numFeatures) {
// Check whether the feature is continuous.
- val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+ val isFeatureContinuous = metadata.isContinuous(featureIndex)
if (isFeatureContinuous) {
val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
val stride: Double = numSamples.toDouble / numBins
@@ -1475,18 +1373,14 @@ object DecisionTree extends Serializable with Logging {
splits(featureIndex)(index) = split
}
} else { // Categorical feature
- val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
- val isSpaceSufficientForAllCategoricalSplits
- = numBins > math.pow(2, featureCategories.toInt - 1) - 1
+ val featureCategories = metadata.featureArity(featureIndex)
// Use different bin/split calculation strategy for categorical features in multiclass
// classification that satisfy the space constraint.
- val isUnorderedFeature =
- isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
- if (isUnorderedFeature) {
+ if (metadata.isUnordered(featureIndex)) {
// 2^(maxFeatureValue - 1) - 1 combinations
var index = 0
- while (index < math.pow(2.0, featureCategories - 1).toInt - 1) {
+ while (index < (1 << featureCategories - 1) - 1) {
val categories: List[Double]
= extractMultiClassCategories(index + 1, featureCategories)
splits(featureIndex)(index)
@@ -1516,7 +1410,7 @@ object DecisionTree extends Serializable with Logging {
* centroidForCategories is a mapping: category (for the given feature) --> centroid
*/
val centroidForCategories = {
- if (isMulticlassClassification) {
+ if (isMulticlass) {
// For categorical variables in multiclass classification,
// each bin is a category. The bins are sorted and they
// are ordered by calculating the impurity of their corresponding labels.
@@ -1524,7 +1418,7 @@ object DecisionTree extends Serializable with Logging {
.groupBy(_._1)
.mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble))
.map(x => (x._1, x._2.values.toArray))
- .map(x => (x._1, strategy.impurity.calculate(x._2, x._2.sum)))
+ .map(x => (x._1, metadata.impurity.calculate(x._2, x._2.sum)))
} else { // regression or binary classification
// For categorical variables in regression and binary classification,
// each bin is a category. The bins are sorted and they
@@ -1576,7 +1470,7 @@ object DecisionTree extends Serializable with Logging {
// Find all bins.
featureIndex = 0
while (featureIndex < numFeatures) {
- val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+ val isFeatureContinuous = metadata.isContinuous(featureIndex)
if (isFeatureContinuous) { // Bins for categorical variables are already assigned.
bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
splits(featureIndex)(0), Continuous, Double.MinValue)
@@ -1590,7 +1484,7 @@ object DecisionTree extends Serializable with Logging {
}
featureIndex += 1
}
- (splits,bins)
+ (splits, bins)
case MinMax =>
throw new UnsupportedOperationException("minmax not supported yet.")
case ApproxHist =>
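
All of the (1 << level) - 1, node.id * 2 + 1, and (nodeIndex - 1) / 2 arithmetic introduced in this file follows one convention: nodes are numbered like a binary heap stored in a flat array. A minimal standalone sketch of that convention (not part of the patch; the object name is invented for illustration):

object NodeIndexingSketch {
  // Level l occupies indices (1 << l) - 1 until (2 << l) - 1; the root is index 0.
  def levelOffset(level: Int): Int = (1 << level) - 1

  // Total nodes in a complete tree of depth maxDepth, as in maxNumNodes = (2 << maxDepth) - 1.
  def maxNumNodes(maxDepth: Int): Int = (2 << maxDepth) - 1

  // Children of node i, matching predictNodeIndex (node.id * 2 + 1 / node.id * 2 + 2).
  def leftChild(i: Int): Int = 2 * i + 1
  def rightChild(i: Int): Int = 2 * i + 2

  // Parent of node i (undefined for the root); left children always have odd indices.
  def parent(i: Int): Int = if (i % 2 == 1) (i - 1) / 2 else (i - 2) / 2

  def main(args: Array[String]): Unit = {
    assert(levelOffset(2) == 3)   // level 2 holds nodes 3, 4, 5, 6
    assert(maxNumNodes(2) == 7)   // 1 + 2 + 4 nodes for depth 2
    assert(leftChild(1) == 3 && rightChild(1) == 4)
    assert(parent(3) == 1 && parent(4) == 1)
  }
}

This is also why extractInfoForLowerLevels only computes the left child index explicitly: the right child is always leftNodeIndex + 1.
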
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
new file mode 100644
index 0000000000..d9eda354dc
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impl
+
+import scala.collection.mutable
+
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.impurity.Impurity
+import org.apache.spark.rdd.RDD
+
+
+/**
+ * Learning and dataset metadata for DecisionTree.
+ *
+ * @param numClasses For classification: labels can take values {0, ..., numClasses - 1}.
+ * For regression: fixed at 0 (no meaning).
+ * @param featureArity Map: categorical feature index --> arity.
+ * I.e., the feature takes values in {0, ..., arity - 1}.
+ */
+private[tree] class DecisionTreeMetadata(
+ val numFeatures: Int,
+ val numExamples: Long,
+ val numClasses: Int,
+ val maxBins: Int,
+ val featureArity: Map[Int, Int],
+ val unorderedFeatures: Set[Int],
+ val impurity: Impurity,
+ val quantileStrategy: QuantileStrategy) extends Serializable {
+
+ def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)
+
+ def isClassification: Boolean = numClasses >= 2
+
+ def isMulticlass: Boolean = numClasses > 2
+
+ def isMulticlassWithCategoricalFeatures: Boolean = isMulticlass && (featureArity.size > 0)
+
+ def isCategorical(featureIndex: Int): Boolean = featureArity.contains(featureIndex)
+
+ def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex)
+
+}
+
+private[tree] object DecisionTreeMetadata {
+
+ def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = {
+
+ val numFeatures = input.take(1)(0).features.size
+ val numExamples = input.count()
+ val numClasses = strategy.algo match {
+ case Classification => strategy.numClassesForClassification
+ case Regression => 0
+ }
+
+ val maxBins = math.min(strategy.maxBins, numExamples).toInt
+ val log2MaxBinsp1 = math.log(maxBins + 1) / math.log(2.0)
+
+ val unorderedFeatures = new mutable.HashSet[Int]()
+ if (numClasses > 2) {
+ strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
+ if (k - 1 < log2MaxBinsp1) {
+ // Note: The above check is equivalent to checking:
+ // numUnorderedBins = (1 << (k - 1)) - 1 < maxBins
+ unorderedFeatures.add(f)
+ } else {
+ // TODO: Allow this case, where we simply will know nothing about some categories?
+ require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " +
+ s"in categorical features (>= $k)")
+ }
+ }
+ } else {
+ strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
+ require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " +
+ s"in categorical features (>= $k)")
+ }
+ }
+
+ new DecisionTreeMetadata(numFeatures, numExamples, numClasses, maxBins,
+ strategy.categoricalFeaturesInfo, unorderedFeatures.toSet,
+ strategy.impurity, strategy.quantileCalculationStrategy)
+ }
+
+}
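
Note (not part of the patch): the arity check in buildMetadata can be read without the logarithm. The sketch below is illustrative only; names such as UnorderedFeatureCheck and canBeUnordered are made up, but the arithmetic matches the check above.

    // Illustrative sketch only -- mirrors the check in DecisionTreeMetadata.buildMetadata.
    object UnorderedFeatureCheck {

      // Candidate bins for an unordered categorical feature with `arity` categories.
      def numUnorderedBins(arity: Int): Int = (1 << (arity - 1)) - 1

      // Equivalent to `arity - 1 < log2(maxBins + 1)`, written without the logarithm.
      def canBeUnordered(arity: Int, maxBins: Int): Boolean =
        numUnorderedBins(arity) < maxBins

      def main(args: Array[String]): Unit = {
        // arity = 3, maxBins = 4 (the "just enough bins" case in the test suite):
        // 2^(3-1) - 1 = 3 < 4, so the feature is treated as unordered.
        println(canBeUnordered(arity = 3, maxBins = 4))    // true
        // arity = 10, maxBins = 100: 2^9 - 1 = 511 >= 100, so the feature stays ordered.
        println(canBeUnordered(arity = 10, maxBins = 100)) // false
      }
    }
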
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
index ccac1031fd..170e43e222 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
@@ -18,7 +18,6 @@
package org.apache.spark.mllib.tree.impl
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.model.Bin
import org.apache.spark.rdd.RDD
@@ -48,50 +47,35 @@ private[tree] object TreePoint {
* Convert an input dataset into its TreePoint representation,
* binning feature values in preparation for DecisionTree training.
* @param input Input dataset.
- * @param strategy DecisionTree training info, used for dataset metadata.
* @param bins Bins for features, of size (numFeatures, numBins).
+ * @param metadata Learning and dataset metadata.
* @return TreePoint dataset representation
*/
def convertToTreeRDD(
input: RDD[LabeledPoint],
- strategy: Strategy,
- bins: Array[Array[Bin]]): RDD[TreePoint] = {
+ bins: Array[Array[Bin]],
+ metadata: DecisionTreeMetadata): RDD[TreePoint] = {
input.map { x =>
- TreePoint.labeledPointToTreePoint(x, strategy.isMulticlassClassification, bins,
- strategy.categoricalFeaturesInfo)
+ TreePoint.labeledPointToTreePoint(x, bins, metadata)
}
}
/**
* Convert one LabeledPoint into its TreePoint representation.
* @param bins Bins for features, of size (numFeatures, numBins).
- * @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity
*/
private def labeledPointToTreePoint(
labeledPoint: LabeledPoint,
- isMulticlassClassification: Boolean,
bins: Array[Array[Bin]],
- categoricalFeaturesInfo: Map[Int, Int]): TreePoint = {
+ metadata: DecisionTreeMetadata): TreePoint = {
val numFeatures = labeledPoint.features.size
val numBins = bins(0).size
val arr = new Array[Int](numFeatures)
var featureIndex = 0
while (featureIndex < numFeatures) {
- val featureInfo = categoricalFeaturesInfo.get(featureIndex)
- val isFeatureContinuous = featureInfo.isEmpty
- if (isFeatureContinuous) {
- arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous, false,
- bins, categoricalFeaturesInfo)
- } else {
- val featureCategories = featureInfo.get
- val isSpaceSufficientForAllCategoricalSplits
- = numBins > math.pow(2, featureCategories.toInt - 1) - 1
- val isUnorderedFeature =
- isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
- arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous,
- isUnorderedFeature, bins, categoricalFeaturesInfo)
- }
+ arr(featureIndex) = findBin(featureIndex, labeledPoint, metadata.isContinuous(featureIndex),
+ metadata.isUnordered(featureIndex), bins, metadata.featureArity)
featureIndex += 1
}
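
Taken together with the new metadata class, the refactored preparation sequence reads as sketched below. This is illustrative only and mirrors the updated test suite; these helpers are private[tree], so the sketch assumes code living in the org.apache.spark.mllib.tree package, and TreePreparationSketch/prepareTrainingData are hypothetical names.

    // Illustrative sketch of the refactored call sequence (not part of the patch).
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.tree.configuration.Strategy
    import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint}
    import org.apache.spark.rdd.RDD

    object TreePreparationSketch {
      def prepareTrainingData(rdd: RDD[LabeledPoint], strategy: Strategy) = {
        // 1. Compute dataset/learning metadata once, up front.
        val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
        // 2. Derive splits and bins from the metadata instead of the raw Strategy.
        val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
        // 3. Bin feature values into the TreePoint representation.
        val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
        (metadata, splits, bins, treeInput)
      }
    }
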
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
index c89c1e371a..af35d88f71 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
@@ -20,15 +20,25 @@ package org.apache.spark.mllib.tree.model
import org.apache.spark.mllib.tree.configuration.FeatureType._
/**
- * Used for "binning" the features bins for faster best split calculation. For a continuous
- * feature, a bin is determined by a low and a high "split". For a categorical feature,
- * the a bin is determined using a single label value (category).
+ * Used for "binning" the features bins for faster best split calculation.
+ *
+ * For a continuous feature, the bin is determined by a low and a high split,
+ * where an example with featureValue falls into the bin s.t.
+ * lowSplit.threshold < featureValue <= highSplit.threshold.
+ *
+ * For ordered categorical features, there is a 1-1-1 correspondence between
+ * bins, splits, and feature values. The bin is determined by category/feature value.
+ * However, the bins are not necessarily ordered by feature value;
+ * they are ordered using impurity.
+ * For unordered categorical features, there is a 1-1 correspondence between bins and splits,
+ * where bins and splits correspond to subsets of feature values (in highSplit.categories).
+ *
* @param lowSplit signifying the lower threshold for the continuous feature to be
* accepted in the bin
* @param highSplit signifying the upper threshold for the continuous feature to be
* accepted in the bin
* @param featureType type of feature -- categorical or continuous
- * @param category categorical label value accepted in the bin for binary classification
+ * @param category categorical label value accepted in the bin for ordered features
*/
private[tree]
case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)
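
The continuous-bin rule in the comment above can be stated as a predicate. The helper below is an illustrative linear-scan sketch only (ContinuousBinLookup and continuousBinIndex are hypothetical names; Bin is private[tree], so this assumes code inside org.apache.spark.mllib.tree).

    // Illustrative sketch only: bin membership for a continuous feature,
    // i.e. lowSplit.threshold < value <= highSplit.threshold.
    import org.apache.spark.mllib.tree.model.Bin

    object ContinuousBinLookup {
      // Linear scan for clarity; returns -1 if no bin matches.
      def continuousBinIndex(value: Double, featureBins: Array[Bin]): Int =
        featureBins.indexWhere { bin =>
          bin.lowSplit.threshold < value && value <= bin.highSplit.threshold
        }
    }

The first bin's DummyLowSplit carries a Double.MinValue threshold (see the findSplitsBins change above), so any value below the first real split lands in bin 0.
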
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index 3d3406b5d5..0594fd0749 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -39,7 +39,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
* @return Double prediction from the trained model
*/
def predict(features: Vector): Double = {
- topNode.predictIfLeaf(features)
+ topNode.predict(features)
}
/**
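
With predictIfLeaf renamed to predict, end-to-end usage reads as below. This is an illustrative sketch, not part of the patch: it assumes an existing SparkContext named sc, and the toy data is made up.

    // Illustrative usage of the renamed prediction entry point (not part of the patch).
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.tree.configuration.Algo._
    import org.apache.spark.mllib.tree.configuration.Strategy
    import org.apache.spark.mllib.tree.impurity.Gini

    val data = sc.parallelize(Seq(
      new LabeledPoint(0.0, Vectors.dense(0.0)),
      new LabeledPoint(1.0, Vectors.dense(1.0))))
    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
      numClassesForClassification = 2)
    val model = DecisionTree.train(data, strategy)
    val prediction = model.predict(Vectors.dense(1.0))  // delegates to topNode.predict(features)
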
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala
deleted file mode 100644
index 2deaf4ae8d..0000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree.model
-
-/**
- * Filter specifying a split and type of comparison to be applied on features
- * @param split split specifying the feature index, type and threshold
- * @param comparison integer specifying <,=,>
- */
-private[tree] case class Filter(split: Split, comparison: Int) {
- // Comparison -1,0,1 signifies <.=,>
- override def toString = " split = " + split + "comparison = " + comparison
-}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 944f11c2c2..0eee626278 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -69,24 +69,24 @@ class Node (
/**
* predict value if node is not leaf
- * @param feature feature value
+ * @param features vector of feature values
* @return predicted value
*/
- def predictIfLeaf(feature: Vector) : Double = {
+ def predict(features: Vector) : Double = {
if (isLeaf) {
predict
} else{
if (split.get.featureType == Continuous) {
- if (feature(split.get.feature) <= split.get.threshold) {
- leftNode.get.predictIfLeaf(feature)
+ if (features(split.get.feature) <= split.get.threshold) {
+ leftNode.get.predict(features)
} else {
- rightNode.get.predictIfLeaf(feature)
+ rightNode.get.predict(features)
}
} else {
- if (split.get.categories.contains(feature(split.get.feature))) {
- leftNode.get.predictIfLeaf(feature)
+ if (split.get.categories.contains(features(split.get.feature))) {
+ leftNode.get.predict(features)
} else {
- rightNode.get.predictIfLeaf(feature)
+ rightNode.get.predict(features)
}
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
index d7ffd386c0..50fb48b40d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
@@ -24,9 +24,10 @@ import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
* :: DeveloperApi ::
* Split applied to a feature
* @param feature feature index
- * @param threshold threshold for continuous feature
+ * @param threshold Threshold for continuous feature.
+ * Split left if feature <= threshold, else right.
* @param featureType type of feature -- categorical or continuous
- * @param categories accepted values for categorical variables
+ * @param categories Split left if categorical feature value is in this set, else right.
*/
@DeveloperApi
case class Split(
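
The routing contract spelled out above matches what Node.predict does. The snippet below is an illustrative restatement only (SplitRouting and goesLeft are hypothetical names). For example, for new Split(0, 400.0, Continuous, List()), a point with features(0) = 350.0 routes left and one with 450.0 routes right.

    // Illustrative sketch only: the left/right routing rule documented on Split.
    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.mllib.tree.configuration.FeatureType.Continuous
    import org.apache.spark.mllib.tree.model.Split

    object SplitRouting {
      def goesLeft(split: Split, features: Vector): Boolean =
        if (split.featureType == Continuous) {
          features(split.feature) <= split.threshold           // continuous: left iff value <= threshold
        } else {
          split.categories.contains(features(split.feature))   // categorical: left iff value in the set
        }
    }
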
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index a5c49a38dc..2f36fd9077 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -23,10 +23,10 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
-import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy}
-import org.apache.spark.mllib.tree.impl.TreePoint
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint}
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
-import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split}
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.regression.LabeledPoint
@@ -64,7 +64,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Gini, 3, 2, 100)
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(bins.length === 2)
assert(splits(0).length === 99)
@@ -82,7 +83,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(bins.length === 2)
assert(splits(0).length === 99)
@@ -162,7 +164,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
// Check splits.
@@ -279,7 +282,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 100,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
// Expecting 2^2 - 1 = 3 bins/splits
assert(splits(0)(0).feature === 0)
@@ -373,7 +377,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 100,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 10, 1-> 10))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
// 2^10 - 1 > 100, so categorical variables will be ordered
@@ -428,10 +433,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
maxDepth = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
- val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0,
- Array[List[Filter]](), splits, bins, 10)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+ val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0,
+ new Array[Node](0), splits, bins, 10)
val split = bestSplits(0)._1
assert(split.categories.length === 1)
@@ -456,10 +462,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
maxDepth = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
- val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0,
- Array[List[Filter]](), splits, bins, 10)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+ val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0,
+ new Array[Node](0), splits, bins, 10)
val split = bestSplits(0)._1
assert(split.categories.length === 1)
@@ -495,7 +502,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Gini, 3, 2, 100)
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(splits(0).length === 99)
assert(bins.length === 2)
@@ -503,9 +511,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(splits(0).length === 99)
assert(bins(0).length === 100)
- val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0,
- Array[List[Filter]](), splits, bins, 10)
+ val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0,
+ new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._2.gain === 0)
@@ -518,7 +526,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Gini, 3, 2, 100)
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(splits(0).length === 99)
assert(bins.length === 2)
@@ -526,9 +535,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(splits(0).length === 99)
assert(bins(0).length === 100)
- val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
- val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0,
- Array[List[Filter]](), splits, bins, 10)
+ val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0,
+ new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._2.gain === 0)
@@ -542,7 +551,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(splits(0).length === 99)
assert(bins.length === 2)
@@ -550,9 +560,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(splits(0).length === 99)
assert(bins(0).length === 100)
- val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
- val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0,
- Array[List[Filter]](), splits, bins, 10)
+ val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0,
+ new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._2.gain === 0)
@@ -566,7 +576,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(splits(0).length === 99)
assert(bins.length === 2)
@@ -574,9 +585,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(splits(0).length === 99)
assert(bins(0).length === 100)
- val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
- val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0,
- Array[List[Filter]](), splits, bins, 10)
+ val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0,
+ new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._2.gain === 0)
@@ -590,7 +601,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(splits(0).length === 99)
assert(bins.length === 2)
@@ -598,14 +610,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(splits(0).length === 99)
assert(bins(0).length === 100)
- val leftFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()), -1)
- val rightFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()) ,1)
- val filters = Array[List[Filter]](List(), List(leftFilter), List(rightFilter))
+ // Train a 1-node model
+ val strategyOneNode = new Strategy(Classification, Entropy, 1, 2, 100)
+ val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
+ val nodes: Array[Node] = new Array[Node](7)
+ nodes(0) = modelOneNode.topNode
+ nodes(0).leftNode = None
+ nodes(0).rightNode = None
+
val parentImpurities = Array(0.5, 0.5, 0.5)
// Single group second level tree construction.
- val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
- val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1, filters,
+ val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1, nodes,
splits, bins, 10)
assert(bestSplits.length === 2)
assert(bestSplits(0)._2.gain > 0)
@@ -613,8 +630,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
// maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second
// level tree construction.
- val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1,
- filters, splits, bins, 0)
+ val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1,
+ nodes, splits, bins, 0)
assert(bestSplitsWithGroups.length === 2)
assert(bestSplitsWithGroups(0)._2.gain > 0)
assert(bestSplitsWithGroups(1)._2.gain > 0)
@@ -629,19 +646,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity)
assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict)
}
-
}
test("stump with categorical variables for multiclass classification") {
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
- val input = sc.parallelize(arr)
+ val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
assert(strategy.isMulticlassClassification)
- val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
- val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
- Array[List[Filter]](), splits, bins, 10)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+ val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+ new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
val bestSplit = bestSplits(0)._1
@@ -657,11 +674,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0))
arr(2) = new LabeledPoint(1.0, Vectors.dense(2.0))
arr(3) = new LabeledPoint(1.0, Vectors.dense(3.0))
- val input = sc.parallelize(arr)
+ val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClassesForClassification = 2)
- val model = DecisionTree.train(input, strategy)
+ val model = DecisionTree.train(rdd, strategy)
validateClassifier(model, arr, 1.0)
assert(model.numNodes === 3)
assert(model.depth === 1)
@@ -688,20 +705,22 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
test("stump with categorical variables for multiclass classification, with just enough bins") {
val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
- val input = sc.parallelize(arr)
+ val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
- numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
+ numClassesForClassification = 3, maxBins = maxBins,
+ categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
assert(strategy.isMulticlassClassification)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
- val model = DecisionTree.train(input, strategy)
+ val model = DecisionTree.train(rdd, strategy)
validateClassifier(model, arr, 1.0)
assert(model.numNodes === 3)
assert(model.depth === 1)
- val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
- val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
- Array[List[Filter]](), splits, bins, 10)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+ val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+ new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
val bestSplit = bestSplits(0)._1
@@ -716,18 +735,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
test("stump with continuous variables for multiclass classification") {
val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
- val input = sc.parallelize(arr)
+ val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClassesForClassification = 3)
assert(strategy.isMulticlassClassification)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
- val model = DecisionTree.train(input, strategy)
+ val model = DecisionTree.train(rdd, strategy)
validateClassifier(model, arr, 0.9)
- val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
- val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
- Array[List[Filter]](), splits, bins, 10)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+ val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+ new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
val bestSplit = bestSplits(0)._1
@@ -741,18 +761,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
test("stump with continuous + categorical variables for multiclass classification") {
val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
- val input = sc.parallelize(arr)
+ val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3))
assert(strategy.isMulticlassClassification)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
- val model = DecisionTree.train(input, strategy)
+ val model = DecisionTree.train(rdd, strategy)
validateClassifier(model, arr, 0.9)
- val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
- val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
- Array[List[Filter]](), splits, bins, 10)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+ val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+ new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
val bestSplit = bestSplits(0)._1
@@ -765,14 +786,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
test("stump with categorical variables for ordered multiclass classification") {
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
- val input = sc.parallelize(arr)
+ val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
assert(strategy.isMulticlassClassification)
- val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
- val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
- Array[List[Filter]](), splits, bins, 10)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+ val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+ new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
val bestSplit = bestSplits(0)._1