aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorQiping Li <liqiping1991@gmail.com>2014-10-09 01:36:58 -0700
committerXiangrui Meng <meng@databricks.com>2014-10-09 01:36:58 -0700
commit14f222f7f76cc93633aae27a94c0e556e289ec56 (patch)
tree4927e861ccc1be46924389295ae957cb43469d8d /mllib/src/test
parent13cab5ba44e2f8d2d2204b3b0d39d7c23a819bdb (diff)
downloadspark-14f222f7f76cc93633aae27a94c0e556e289ec56.tar.gz
spark-14f222f7f76cc93633aae27a94c0e556e289ec56.tar.bz2
spark-14f222f7f76cc93633aae27a94c0e556e289ec56.zip
[SPARK-3158][MLLIB]Avoid 1 extra aggregation for DecisionTree training
Currently, the implementation does one unnecessary aggregation step. The aggregation step for level L (to choose splits) gives enough information to set the predictions of any leaf nodes at level L+1. We can use that info and skip the aggregation step for the last level of the tree (which only has leaf nodes). ### Implementation Details Each node now has a `impurity` field and the `predict` is changed from type `Double` to type `Predict`(this can be used to compute predict probability in the future) When compute best splits for each node, we also compute impurity and predict for the child nodes, which is used to constructed newly allocated child nodes. So at level L, we have set impurity and predict for nodes at level L +1. If level L+1 is the last level, then we can avoid aggregation. What's more, calculation of parent impurity in Top nodes for each tree needs to be treated differently because we have to compute impurity and predict for them first. In `binsToBestSplit`, if current node is top node(level == 0), we calculate impurity and predict first. after finding best split, top node's predict and impurity is set to the calculated value. Non-top nodes's impurity and predict are already calculated and don't need to be recalculated again. I have considered to add a initialization step to set top nodes' impurity and predict and then we can treat all nodes in the same way, but this will need a lot of duplication of code(all the code to do seq operation(BinSeqOp) needs to be duplicated), so I choose the current way. CC mengxr manishamde jkbradley, please help me review this, thanks. Author: Qiping Li <liqiping1991@gmail.com> Closes #2708 from chouqin/avoid-agg and squashes the following commits: 8e269ea [Qiping Li] adjust code and comments eefeef1 [Qiping Li] adjust comments and check child nodes' impurity c41b1b6 [Qiping Li] fix pyspark unit test 7ad7a71 [Qiping Li] fix unit test 822c912 [Qiping Li] add comments and unit test e41d715 [Qiping Li] fix bug in test suite 6cc0333 [Qiping Li] SPARK-3158: Avoid 1 extra aggregation for DecisionTree training
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala102
1 files changed, 94 insertions, 8 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index a48ed71a1c..98a72b0c4d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -253,7 +253,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val stats = rootNode.stats.get
assert(stats.gain > 0)
- assert(rootNode.predict === 1)
+ assert(rootNode.predict.predict === 1)
assert(stats.impurity > 0.2)
}
@@ -282,7 +282,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val stats = rootNode.stats.get
assert(stats.gain > 0)
- assert(rootNode.predict === 0.6)
+ assert(rootNode.predict.predict === 0.6)
assert(stats.impurity > 0.2)
}
@@ -352,7 +352,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(stats.gain === 0)
assert(stats.leftImpurity === 0)
assert(stats.rightImpurity === 0)
- assert(rootNode.predict === 1)
+ assert(rootNode.predict.predict === 1)
}
test("Binary classification stump with fixed label 0 for Entropy") {
@@ -377,7 +377,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(stats.gain === 0)
assert(stats.leftImpurity === 0)
assert(stats.rightImpurity === 0)
- assert(rootNode.predict === 0)
+ assert(rootNode.predict.predict === 0)
}
test("Binary classification stump with fixed label 1 for Entropy") {
@@ -402,7 +402,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(stats.gain === 0)
assert(stats.leftImpurity === 0)
assert(stats.rightImpurity === 0)
- assert(rootNode.predict === 1)
+ assert(rootNode.predict.predict === 1)
}
test("Second level node building with vs. without groups") {
@@ -471,7 +471,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(stats1.impurity === stats2.impurity)
assert(stats1.leftImpurity === stats2.leftImpurity)
assert(stats1.rightImpurity === stats2.rightImpurity)
- assert(children1(i).predict === children2(i).predict)
+ assert(children1(i).predict.predict === children2(i).predict.predict)
}
}
@@ -646,7 +646,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val model = DecisionTree.train(rdd, strategy)
assert(model.topNode.isLeaf)
- assert(model.topNode.predict == 0.0)
+ assert(model.topNode.predict.predict == 0.0)
val predicts = rdd.map(p => model.predict(p.features)).collect()
predicts.foreach { predict =>
assert(predict == 0.0)
@@ -693,7 +693,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val model = DecisionTree.train(input, strategy)
assert(model.topNode.isLeaf)
- assert(model.topNode.predict == 0.0)
+ assert(model.topNode.predict.predict == 0.0)
val predicts = input.map(p => model.predict(p.features)).collect()
predicts.foreach { predict =>
assert(predict == 0.0)
@@ -705,6 +705,92 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val gain = rootNode.stats.get
assert(gain == InformationGainStats.invalidInformationGainStats)
}
+
+ test("Avoid aggregation on the last level") {
+ val arr = new Array[LabeledPoint](4)
+ arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0))
+ arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))
+ arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0))
+ arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))
+ val input = sc.parallelize(arr)
+
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
+ numClassesForClassification = 2, categoricalFeaturesInfo = Map(0 -> 3))
+ val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+
+ val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+ val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+
+ val topNode = Node.emptyNode(nodeIndex = 1)
+ assert(topNode.predict.predict === Double.MinValue)
+ assert(topNode.impurity === -1.0)
+ assert(topNode.isLeaf === false)
+
+ val nodesForGroup = Map((0, Array(topNode)))
+ val treeToNodeToIndexInfo = Map((0, Map(
+ (topNode.id, new RandomForest.NodeIndexInfo(0, None))
+ )))
+ val nodeQueue = new mutable.Queue[(Int, Node)]()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
+ nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
+
+ // don't enqueue leaf nodes into node queue
+ assert(nodeQueue.isEmpty)
+
+ // set impurity and predict for topNode
+ assert(topNode.predict.predict !== Double.MinValue)
+ assert(topNode.impurity !== -1.0)
+
+ // set impurity and predict for child nodes
+ assert(topNode.leftNode.get.predict.predict === 0.0)
+ assert(topNode.rightNode.get.predict.predict === 1.0)
+ assert(topNode.leftNode.get.impurity === 0.0)
+ assert(topNode.rightNode.get.impurity === 0.0)
+ }
+
+ test("Avoid aggregation if impurity is 0.0") {
+ val arr = new Array[LabeledPoint](4)
+ arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0))
+ arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))
+ arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0))
+ arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))
+ val input = sc.parallelize(arr)
+
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+ numClassesForClassification = 2, categoricalFeaturesInfo = Map(0 -> 3))
+ val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+
+ val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+ val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+
+ val topNode = Node.emptyNode(nodeIndex = 1)
+ assert(topNode.predict.predict === Double.MinValue)
+ assert(topNode.impurity === -1.0)
+ assert(topNode.isLeaf === false)
+
+ val nodesForGroup = Map((0, Array(topNode)))
+ val treeToNodeToIndexInfo = Map((0, Map(
+ (topNode.id, new RandomForest.NodeIndexInfo(0, None))
+ )))
+ val nodeQueue = new mutable.Queue[(Int, Node)]()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
+ nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
+
+ // don't enqueue a node into node queue if its impurity is 0.0
+ assert(nodeQueue.isEmpty)
+
+ // set impurity and predict for topNode
+ assert(topNode.predict.predict !== Double.MinValue)
+ assert(topNode.impurity !== -1.0)
+
+ // set impurity and predict for child nodes
+ assert(topNode.leftNode.get.predict.predict === 0.0)
+ assert(topNode.rightNode.get.predict.predict === 1.0)
+ assert(topNode.leftNode.get.impurity === 0.0)
+ assert(topNode.rightNode.get.impurity === 0.0)
+ }
}
object DecisionTreeSuite {