Diffstat (limited to 'mllib')
 mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala           | 103
 mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala |   6
 mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala      |  84
 3 files changed, 170 insertions(+), 23 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 59ed01debf..0fe30a3e70 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -54,12 +54,13 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
// Find the splits and the corresponding bins (interval between the splits) using a sample
// of the input data.
val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
- logDebug("numSplits = " + bins(0).length)
+ val numBins = bins(0).length
+ logDebug("numBins = " + numBins)
// depth of the decision tree
val maxDepth = strategy.maxDepth
// the max number of nodes possible given the depth of the tree
- val maxNumNodes = scala.math.pow(2, maxDepth).toInt - 1
+ val maxNumNodes = math.pow(2, maxDepth).toInt - 1
// Initialize an array to hold filters applied to points for each node.
val filters = new Array[List[Filter]](maxNumNodes)
// The filter at the top node is an empty list.
@@ -68,7 +69,28 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
val parentImpurities = new Array[Double](maxNumNodes)
// dummy value for top node (updated during first split calculation)
val nodes = new Array[Node](maxNumNodes)
+ // Number of features, determined from the first sample.
+ val numFeatures = input.take(1)(0).features.size
+
+ // Calculate level for single group construction
+ // Max memory usage for aggregates
+ val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024
+ logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
+ val numElementsPerNode =
+ strategy.algo match {
+ case Classification => 2 * numBins * numFeatures
+ case Regression => 3 * numBins * numFeatures
+ }
+
+ logDebug("numElementsPerNode = " + numElementsPerNode)
+ val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array
+ val maxNumberOfNodesPerGroup = math.max(maxMemoryUsage / arraySizePerNode, 1)
+ logDebug("maxNumberOfNodesPerGroup = " + maxNumberOfNodesPerGroup)
+ // The number of nodes at a level is 2^level (levels are zero-indexed).
+ val maxLevelForSingleGroup = math.max(
+ (math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt, 0)
+ logDebug("max level for single group = " + maxLevelForSingleGroup)
/*
* The main idea here is to perform level-wise training of the decision tree nodes thus
@@ -88,7 +110,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
// Find best split for all nodes at a level.
val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy,
- level, filters, splits, bins)
+ level, filters, splits, bins, maxLevelForSingleGroup)
for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
// Extract info for nodes at the current level.
@@ -98,7 +120,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
filters)
logDebug("final best split = " + nodeSplitStats._1)
}
- require(scala.math.pow(2, level) == splitsStatsForLevel.length)
+ require(math.pow(2, level) == splitsStatsForLevel.length)
// Check whether all the nodes at the current level are leaves.
val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
logDebug("all leaf = " + allLeaf)
@@ -109,6 +131,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
}
}
+ logDebug("#####################################")
+ logDebug("Extracting tree model")
+ logDebug("#####################################")
+
// Initialize the top or root node of the tree.
val topNode = nodes(0)
// Build the full tree using the node info calculated in the level-wise best split calculations.
@@ -127,7 +153,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
nodes: Array[Node]): Unit = {
val split = nodeSplitStats._1
val stats = nodeSplitStats._2
- val nodeIndex = scala.math.pow(2, level).toInt - 1 + index
+ val nodeIndex = math.pow(2, level).toInt - 1 + index
val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth - 1)
val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
logDebug("Node = " + node)
@@ -148,7 +174,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
var i = 0
while (i <= 1) {
// Calculate the index of the node from the node level and the index at the current level.
- val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i
+ val nodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index + i
if (level < maxDepth - 1) {
val impurity = if (i == 0) {
nodeSplitStats._2.leftImpurity
@@ -249,7 +275,8 @@ object DecisionTree extends Serializable with Logging {
private val InvalidBinIndex = -1
/**
- * Returns an array of optimal splits for all nodes at a given level
+ * Returns an array of optimal splits for all nodes at a given level. Splits the computation
+ * into multiple groups if level-wise training could otherwise lead to memory overflow.
*
* @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
* for DecisionTree
@@ -260,6 +287,7 @@ object DecisionTree extends Serializable with Logging {
* @param filters Filters for all nodes at a given level
* @param splits possible splits for all features
* @param bins possible bins for all features
+ * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation.
* @return array of splits with best splits for all nodes at a given level.
*/
protected[tree] def findBestSplits(
@@ -269,7 +297,57 @@ object DecisionTree extends Serializable with Logging {
level: Int,
filters: Array[List[Filter]],
splits: Array[Array[Split]],
- bins: Array[Array[Bin]]): Array[(Split, InformationGainStats)] = {
+ bins: Array[Array[Bin]],
+ maxLevelForSingleGroup: Int): Array[(Split, InformationGainStats)] = {
+ // split into groups to avoid memory overflow during aggregation
+ if (level > maxLevelForSingleGroup) {
+ // When information for all nodes at a given level cannot be stored in memory,
+ // the nodes are divided into multiple groups at each level with the number of groups
+ // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10,
+ // numGroups is 2 at level 11 and 4 at level 12.
+ val numGroups = math.pow(2, level - maxLevelForSingleGroup).toInt
+ logDebug("numGroups = " + numGroups)
+ var bestSplits = new Array[(Split, InformationGainStats)](0)
+ // Iterate over each group of nodes at a level.
+ var groupIndex = 0
+ while (groupIndex < numGroups) {
+ val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level,
+ filters, splits, bins, numGroups, groupIndex)
+ bestSplits = Array.concat(bestSplits, bestSplitsForGroup)
+ groupIndex += 1
+ }
+ bestSplits
+ } else {
+ findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins)
+ }
+ }
+
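To make the grouping concrete, a small worked sketch assuming maxLevelForSingleGroup = 10 (illustrative values, not part of the patch):

    val level = 12
    val maxLevelForSingleGroup = 10
    val numGroups = math.pow(2, level - maxLevelForSingleGroup).toInt // 4 groups at level 12
    val numNodes = math.pow(2, level).toInt / numGroups               // 1024 nodes per group
    // findBestSplitsPerGroup maps a group-local node index back to its level-global
    // slot via groupShift = numNodes * groupIndex: node 5 of group 2 reads its
    // filters and parent impurity at level offset 1024 * 2 + 5 = 2053.

Each group performs its own aggregation pass over the input, so deeper levels trade extra passes for a bounded bin aggregate array.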
+ /**
+ * Returns an array of optimal splits for a group of nodes at a given level
+ *
+ * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
+ * for DecisionTree
+ * @param parentImpurities Impurities for all parent nodes for the current level
+ * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
+ * parameters for constructing the DecisionTree
+ * @param level Level of the tree
+ * @param filters Filters for all nodes at a given level
+ * @param splits possible splits for all features
+ * @param bins possible bins for all features
+ * @param numGroups total number of node groups at the current level. Default value is set to 1.
+ * @param groupIndex index of the node group being processed. Default value is set to 0.
+ * @return array of splits with best splits for the group of nodes at a given level.
+ */
+ private def findBestSplitsPerGroup(
+ input: RDD[LabeledPoint],
+ parentImpurities: Array[Double],
+ strategy: Strategy,
+ level: Int,
+ filters: Array[List[Filter]],
+ splits: Array[Array[Split]],
+ bins: Array[Array[Bin]],
+ numGroups: Int = 1,
+ groupIndex: Int = 0): Array[(Split, InformationGainStats)] = {
/*
* The high-level description for the best split optimizations are noted here.
@@ -296,7 +374,7 @@ object DecisionTree extends Serializable with Logging {
*/
// common calculations for multiple nested methods
- val numNodes = scala.math.pow(2, level).toInt
+ val numNodes = math.pow(2, level).toInt / numGroups
logDebug("numNodes = " + numNodes)
// Find the number of features by looking at the first sample.
val numFeatures = input.first().features.size
@@ -304,12 +382,15 @@ object DecisionTree extends Serializable with Logging {
val numBins = bins(0).length
logDebug("numBins = " + numBins)
+ // shift applied to node indices when more than one group is used at a deep tree level
+ val groupShift = numNodes * groupIndex
+
/** Find the filters used before reaching the current node. */
def findParentFilters(nodeIndex: Int): List[Filter] = {
if (level == 0) {
List[Filter]()
} else {
- val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex
+ val nodeFilterIndex = math.pow(2, level).toInt - 1 + nodeIndex + groupShift
filters(nodeFilterIndex)
}
}
@@ -878,7 +959,7 @@ object DecisionTree extends Serializable with Logging {
// Iterating over all nodes at this level
var node = 0
while (node < numNodes) {
- val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node
+ val nodeImpurityIndex = math.pow(2, level).toInt - 1 + node + groupShift
val binsForNode: Array[Double] = getBinDataForNode(node)
logDebug("nodeImpurityIndex = " + nodeImpurityIndex)
val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 8767aca47c..1b505fd76e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -35,6 +35,9 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* k) implies the feature n is categorical with k categories 0,
* 1, 2, ... , k-1. It's important to note that features are
* zero-indexed.
+ * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is
+ * 128 MB.
+ *
*/
@Experimental
class Strategy (
@@ -43,4 +46,5 @@ class Strategy (
val maxDepth: Int,
val maxBins: Int = 100,
val quantileCalculationStrategy: QuantileStrategy = Sort,
- val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]()) extends Serializable
+ val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
+ val maxMemoryInMB: Int = 128) extends Serializable
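A usage sketch for the new parameter (trainingData is an assumed RDD[LabeledPoint]; the other constructor arguments keep their defaults):

    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.tree.configuration.Algo.Classification
    import org.apache.spark.mllib.tree.configuration.Strategy
    import org.apache.spark.mllib.tree.impurity.Gini

    // Give histogram aggregation a 256 MB budget instead of the default 128 MB,
    // so deep levels are split into half as many groups.
    val strategy = new Strategy(
      algo = Classification,
      impurity = Gini,
      maxDepth = 6,
      maxMemoryInMB = 256)
    val model = DecisionTree.train(trainingData, strategy)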
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index be383aab71..35e92d71dc 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -22,7 +22,8 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model.Filter
-import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.model.Split
+import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.linalg.Vectors
@@ -242,7 +243,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
- Array[List[Filter]](), splits, bins)
+ Array[List[Filter]](), splits, bins, 10)
val split = bestSplits(0)._1
assert(split.categories.length === 1)
@@ -269,7 +270,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
- Array[List[Filter]](), splits, bins)
+ Array[List[Filter]](), splits, bins, 10)
val split = bestSplits(0)._1
assert(split.categories.length === 1)
@@ -298,7 +299,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins(0).length === 100)
val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
- Array[List[Filter]](), splits, bins)
+ Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._1.threshold === 10)
@@ -321,7 +322,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins(0).length === 100)
val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
- Array[List[Filter]](), splits, bins)
+ Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._1.threshold === 10)
@@ -345,7 +346,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins(0).length === 100)
val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
- Array[List[Filter]](), splits, bins)
+ Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._1.threshold === 10)
@@ -369,7 +370,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins(0).length === 100)
val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
- Array[List[Filter]](), splits, bins)
+ Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._1.threshold === 10)
@@ -378,13 +379,60 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(0)._2.rightImpurity === 0)
assert(bestSplits(0)._2.predict === 1)
}
+
+ test("test second level node building with/without groups") {
+ val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(Classification, Entropy, 3, 100)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ assert(splits.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins.length === 2)
+ assert(bins(0).length === 100)
+
+ val leftFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()), -1)
+ val rightFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()), 1)
+ val filters = Array[List[Filter]](List(), List(leftFilter), List(rightFilter))
+ val parentImpurities = Array(0.5, 0.5, 0.5)
+
+ // Single group second level tree construction.
+ val bestSplits = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, filters,
+ splits, bins, 10)
+ assert(bestSplits.length === 2)
+ assert(bestSplits(0)._2.gain > 0)
+ assert(bestSplits(1)._2.gain > 0)
+
+ // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second
+ // level tree construction.
+ val bestSplitsWithGroups = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1,
+ filters, splits, bins, 0)
+ assert(bestSplitsWithGroups.length === 2)
+ assert(bestSplitsWithGroups(0)._2.gain > 0)
+ assert(bestSplitsWithGroups(1)._2.gain > 0)
+
+ // Verify whether the splits obtained using single group and multiple group level
+ // construction strategies are the same.
+ for (i <- 0 until bestSplits.length) {
+ assert(bestSplits(i)._1 === bestSplitsWithGroups(i)._1)
+ assert(bestSplits(i)._2.gain === bestSplitsWithGroups(i)._2.gain)
+ assert(bestSplits(i)._2.impurity === bestSplitsWithGroups(i)._2.impurity)
+ assert(bestSplits(i)._2.leftImpurity === bestSplitsWithGroups(i)._2.leftImpurity)
+ assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity)
+ assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict)
+ }
+ }
+
}
object DecisionTreeSuite {
def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](1000)
- for (i <- 0 until 1000){
+ for (i <- 0 until 1000) {
val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
arr(i) = lp
}
@@ -393,17 +441,31 @@ object DecisionTreeSuite {
def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](1000)
- for (i <- 0 until 1000){
+ for (i <- 0 until 1000) {
val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i))
arr(i) = lp
}
arr
}
+ def generateOrderedLabeledPoints(): Array[LabeledPoint] = {
+ val arr = new Array[LabeledPoint](1000)
+ for (i <- 0 until 1000) {
+ if (i < 600) {
+ val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
+ arr(i) = lp
+ } else {
+ val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 1000.0 - i))
+ arr(i) = lp
+ }
+ }
+ arr
+ }
+
def generateCategoricalDataPoints(): Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](1000)
- for (i <- 0 until 1000){
- if (i < 600){
+ for (i <- 0 until 1000) {
+ if (i < 600) {
arr(i) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0))
} else {
arr(i) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0))