aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala97
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala9
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala37
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala102
4 files changed, 197 insertions, 48 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index b311d10023..03eeaa7077 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -532,6 +532,14 @@ object DecisionTree extends Serializable with Logging {
Some(mutableNodeToFeatures.toMap)
}
+ // array of nodes to train indexed by node index in group
+ val nodes = new Array[Node](numNodes)
+ nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
+ nodesForTree.foreach { node =>
+ nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node
+ }
+ }
+
// Calculate best splits for all nodes in the group
timer.start("chooseSplits")
@@ -568,7 +576,7 @@ object DecisionTree extends Serializable with Logging {
// find best split for each node
val (split: Split, stats: InformationGainStats, predict: Predict) =
- binsToBestSplit(aggStats, splits, featuresForNode)
+ binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
(nodeIndex, (split, stats, predict))
}.collectAsMap()
@@ -587,17 +595,30 @@ object DecisionTree extends Serializable with Logging {
// Extract info for this node. Create children if not leaf.
val isLeaf = (stats.gain <= 0) || (Node.indexToLevel(nodeIndex) == metadata.maxDepth)
assert(node.id == nodeIndex)
- node.predict = predict.predict
+ node.predict = predict
node.isLeaf = isLeaf
node.stats = Some(stats)
+ node.impurity = stats.impurity
logDebug("Node = " + node)
if (!isLeaf) {
node.split = Some(split)
- node.leftNode = Some(Node.emptyNode(Node.leftChildIndex(nodeIndex)))
- node.rightNode = Some(Node.emptyNode(Node.rightChildIndex(nodeIndex)))
- nodeQueue.enqueue((treeIndex, node.leftNode.get))
- nodeQueue.enqueue((treeIndex, node.rightNode.get))
+ val childIsLeaf = (Node.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
+ val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
+ val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
+ node.leftNode = Some(Node(Node.leftChildIndex(nodeIndex),
+ stats.leftPredict, stats.leftImpurity, leftChildIsLeaf))
+ node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex),
+ stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))
+
+ // enqueue left child and right child if they are not leaves
+ if (!leftChildIsLeaf) {
+ nodeQueue.enqueue((treeIndex, node.leftNode.get))
+ }
+ if (!rightChildIsLeaf) {
+ nodeQueue.enqueue((treeIndex, node.rightNode.get))
+ }
+
logDebug("leftChildIndex = " + node.leftNode.get.id +
", impurity = " + stats.leftImpurity)
logDebug("rightChildIndex = " + node.rightNode.get.id +
@@ -617,7 +638,8 @@ object DecisionTree extends Serializable with Logging {
private def calculateGainForSplit(
leftImpurityCalculator: ImpurityCalculator,
rightImpurityCalculator: ImpurityCalculator,
- metadata: DecisionTreeMetadata): InformationGainStats = {
+ metadata: DecisionTreeMetadata,
+ impurity: Double): InformationGainStats = {
val leftCount = leftImpurityCalculator.count
val rightCount = rightImpurityCalculator.count
@@ -630,11 +652,6 @@ object DecisionTree extends Serializable with Logging {
val totalCount = leftCount + rightCount
- val parentNodeAgg = leftImpurityCalculator.copy
- parentNodeAgg.add(rightImpurityCalculator)
-
- val impurity = parentNodeAgg.calculate()
-
val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
val rightImpurity = rightImpurityCalculator.calculate()
@@ -649,7 +666,18 @@ object DecisionTree extends Serializable with Logging {
return InformationGainStats.invalidInformationGainStats
}
- new InformationGainStats(gain, impurity, leftImpurity, rightImpurity)
+ // calculate left and right predict
+ val leftPredict = calculatePredict(leftImpurityCalculator)
+ val rightPredict = calculatePredict(rightImpurityCalculator)
+
+ new InformationGainStats(gain, impurity, leftImpurity, rightImpurity,
+ leftPredict, rightPredict)
+ }
+
+ private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = {
+ val predict = impurityCalculator.predict
+ val prob = impurityCalculator.prob(predict)
+ new Predict(predict, prob)
}
/**
@@ -657,17 +685,17 @@ object DecisionTree extends Serializable with Logging {
* Note that this function is called only once for each node.
* @param leftImpurityCalculator left node aggregates for a split
* @param rightImpurityCalculator right node aggregates for a split
- * @return predict value for current node
+ * @return predict value and impurity for current node
*/
- private def calculatePredict(
+ private def calculatePredictImpurity(
leftImpurityCalculator: ImpurityCalculator,
- rightImpurityCalculator: ImpurityCalculator): Predict = {
+ rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = {
val parentNodeAgg = leftImpurityCalculator.copy
parentNodeAgg.add(rightImpurityCalculator)
- val predict = parentNodeAgg.predict
- val prob = parentNodeAgg.prob(predict)
+ val predict = calculatePredict(parentNodeAgg)
+ val impurity = parentNodeAgg.calculate()
- new Predict(predict, prob)
+ (predict, impurity)
}
/**
@@ -678,10 +706,16 @@ object DecisionTree extends Serializable with Logging {
private def binsToBestSplit(
binAggregates: DTStatsAggregator,
splits: Array[Array[Split]],
- featuresForNode: Option[Array[Int]]): (Split, InformationGainStats, Predict) = {
+ featuresForNode: Option[Array[Int]],
+ node: Node): (Split, InformationGainStats, Predict) = {
- // calculate predict only once
- var predict: Option[Predict] = None
+ // calculate predict and impurity if current node is top node
+ val level = Node.indexToLevel(node.id)
+ var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) {
+ None
+ } else {
+ Some((node.predict, node.impurity))
+ }
// For each (feature, split), calculate the gain, and select the best (feature, split).
val (bestSplit, bestSplitStats) =
@@ -708,9 +742,10 @@ object DecisionTree extends Serializable with Logging {
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
rightChildStats.subtract(leftChildStats)
- predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
+ predictWithImpurity = Some(predictWithImpurity.getOrElse(
+ calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
- rightChildStats, binAggregates.metadata)
+ rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIdx, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
@@ -722,9 +757,10 @@ object DecisionTree extends Serializable with Logging {
Range(0, numSplits).map { splitIndex =>
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
- predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
+ predictWithImpurity = Some(predictWithImpurity.getOrElse(
+ calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
- rightChildStats, binAggregates.metadata)
+ rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
@@ -794,9 +830,10 @@ object DecisionTree extends Serializable with Logging {
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
rightChildStats.subtract(leftChildStats)
- predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
+ predictWithImpurity = Some(predictWithImpurity.getOrElse(
+ calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
- rightChildStats, binAggregates.metadata)
+ rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
val categoriesForSplit =
@@ -807,9 +844,7 @@ object DecisionTree extends Serializable with Logging {
}
}.maxBy(_._2.gain)
- assert(predict.isDefined, "must calculate predict for each node")
-
- (bestSplit, bestSplitStats, predict.get)
+ (bestSplit, bestSplitStats, predictWithImpurity.get._1)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index a89e71e115..9a50ecb550 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -26,13 +26,17 @@ import org.apache.spark.annotation.DeveloperApi
* @param impurity current node impurity
* @param leftImpurity left node impurity
* @param rightImpurity right node impurity
+ * @param leftPredict left node predict
+ * @param rightPredict right node predict
*/
@DeveloperApi
class InformationGainStats(
val gain: Double,
val impurity: Double,
val leftImpurity: Double,
- val rightImpurity: Double) extends Serializable {
+ val rightImpurity: Double,
+ val leftPredict: Predict,
+ val rightPredict: Predict) extends Serializable {
override def toString = {
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f"
@@ -58,5 +62,6 @@ private[tree] object InformationGainStats {
* denote that current split doesn't satisfies minimum info gain or
* minimum number of instances per node.
*/
- val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0)
+ val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0,
+ new Predict(0.0, 0.0), new Predict(0.0, 0.0))
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 56c3e25d92..2179da8dbe 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -32,7 +32,8 @@ import org.apache.spark.mllib.linalg.Vector
*
* @param id integer node id, from 1
* @param predict predicted value at the node
- * @param isLeaf whether the leaf is a node
+ * @param impurity current node impurity
+ * @param isLeaf whether the node is a leaf
* @param split split to calculate left and right nodes
* @param leftNode left child
* @param rightNode right child
@@ -41,7 +42,8 @@ import org.apache.spark.mllib.linalg.Vector
@DeveloperApi
class Node (
val id: Int,
- var predict: Double,
+ var predict: Predict,
+ var impurity: Double,
var isLeaf: Boolean,
var split: Option[Split],
var leftNode: Option[Node],
@@ -49,7 +51,7 @@ class Node (
var stats: Option[InformationGainStats]) extends Serializable with Logging {
override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " +
- "split = " + split + ", stats = " + stats
+ "impurity = " + impurity + "split = " + split + ", stats = " + stats
/**
* build the left node and right nodes if not leaf
@@ -62,6 +64,7 @@ class Node (
logDebug("id = " + id + ", split = " + split)
logDebug("stats = " + stats)
logDebug("predict = " + predict)
+ logDebug("impurity = " + impurity)
if (!isLeaf) {
leftNode = Some(nodes(Node.leftChildIndex(id)))
rightNode = Some(nodes(Node.rightChildIndex(id)))
@@ -77,7 +80,7 @@ class Node (
*/
def predict(features: Vector) : Double = {
if (isLeaf) {
- predict
+ predict.predict
} else{
if (split.get.featureType == Continuous) {
if (features(split.get.feature) <= split.get.threshold) {
@@ -109,7 +112,7 @@ class Node (
} else {
Some(rightNode.get.deepCopy())
}
- new Node(id, predict, isLeaf, split, leftNodeCopy, rightNodeCopy, stats)
+ new Node(id, predict, impurity, isLeaf, split, leftNodeCopy, rightNodeCopy, stats)
}
/**
@@ -154,7 +157,7 @@ class Node (
}
val prefix: String = " " * indentFactor
if (isLeaf) {
- prefix + s"Predict: $predict\n"
+ prefix + s"Predict: ${predict.predict}\n"
} else {
prefix + s"If ${splitToString(split.get, left=true)}\n" +
leftNode.get.subtreeToString(indentFactor + 1) +
@@ -170,7 +173,27 @@ private[tree] object Node {
/**
* Return a node with the given node id (but nothing else set).
*/
- def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, 0, false, None, None, None, None)
+ def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, new Predict(Double.MinValue), -1.0,
+ false, None, None, None, None)
+
+ /**
+ * Construct a node with nodeIndex, predict, impurity and isLeaf parameters.
+ * This is used in `DecisionTree.findBestSplits` to construct child nodes
+ * after finding the best splits for parent nodes.
+ * Other fields are set at next level.
+ * @param nodeIndex integer node id, from 1
+ * @param predict predicted value at the node
+ * @param impurity current node impurity
+ * @param isLeaf whether the node is a leaf
+ * @return new node instance
+ */
+ def apply(
+ nodeIndex: Int,
+ predict: Predict,
+ impurity: Double,
+ isLeaf: Boolean): Node = {
+ new Node(nodeIndex, predict, impurity, isLeaf, None, None, None, None)
+ }
/**
* Return the index of the left child of this node.
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index a48ed71a1c..98a72b0c4d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -253,7 +253,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val stats = rootNode.stats.get
assert(stats.gain > 0)
- assert(rootNode.predict === 1)
+ assert(rootNode.predict.predict === 1)
assert(stats.impurity > 0.2)
}
@@ -282,7 +282,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val stats = rootNode.stats.get
assert(stats.gain > 0)
- assert(rootNode.predict === 0.6)
+ assert(rootNode.predict.predict === 0.6)
assert(stats.impurity > 0.2)
}
@@ -352,7 +352,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(stats.gain === 0)
assert(stats.leftImpurity === 0)
assert(stats.rightImpurity === 0)
- assert(rootNode.predict === 1)
+ assert(rootNode.predict.predict === 1)
}
test("Binary classification stump with fixed label 0 for Entropy") {
@@ -377,7 +377,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(stats.gain === 0)
assert(stats.leftImpurity === 0)
assert(stats.rightImpurity === 0)
- assert(rootNode.predict === 0)
+ assert(rootNode.predict.predict === 0)
}
test("Binary classification stump with fixed label 1 for Entropy") {
@@ -402,7 +402,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(stats.gain === 0)
assert(stats.leftImpurity === 0)
assert(stats.rightImpurity === 0)
- assert(rootNode.predict === 1)
+ assert(rootNode.predict.predict === 1)
}
test("Second level node building with vs. without groups") {
@@ -471,7 +471,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(stats1.impurity === stats2.impurity)
assert(stats1.leftImpurity === stats2.leftImpurity)
assert(stats1.rightImpurity === stats2.rightImpurity)
- assert(children1(i).predict === children2(i).predict)
+ assert(children1(i).predict.predict === children2(i).predict.predict)
}
}
@@ -646,7 +646,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val model = DecisionTree.train(rdd, strategy)
assert(model.topNode.isLeaf)
- assert(model.topNode.predict == 0.0)
+ assert(model.topNode.predict.predict == 0.0)
val predicts = rdd.map(p => model.predict(p.features)).collect()
predicts.foreach { predict =>
assert(predict == 0.0)
@@ -693,7 +693,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val model = DecisionTree.train(input, strategy)
assert(model.topNode.isLeaf)
- assert(model.topNode.predict == 0.0)
+ assert(model.topNode.predict.predict == 0.0)
val predicts = input.map(p => model.predict(p.features)).collect()
predicts.foreach { predict =>
assert(predict == 0.0)
@@ -705,6 +705,92 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val gain = rootNode.stats.get
assert(gain == InformationGainStats.invalidInformationGainStats)
}
+
+ test("Avoid aggregation on the last level") {
+ val arr = new Array[LabeledPoint](4)
+ arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0))
+ arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))
+ arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0))
+ arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))
+ val input = sc.parallelize(arr)
+
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
+ numClassesForClassification = 2, categoricalFeaturesInfo = Map(0 -> 3))
+ val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+
+ val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+ val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+
+ val topNode = Node.emptyNode(nodeIndex = 1)
+ assert(topNode.predict.predict === Double.MinValue)
+ assert(topNode.impurity === -1.0)
+ assert(topNode.isLeaf === false)
+
+ val nodesForGroup = Map((0, Array(topNode)))
+ val treeToNodeToIndexInfo = Map((0, Map(
+ (topNode.id, new RandomForest.NodeIndexInfo(0, None))
+ )))
+ val nodeQueue = new mutable.Queue[(Int, Node)]()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
+ nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
+
+ // don't enqueue leaf nodes into node queue
+ assert(nodeQueue.isEmpty)
+
+ // set impurity and predict for topNode
+ assert(topNode.predict.predict !== Double.MinValue)
+ assert(topNode.impurity !== -1.0)
+
+ // set impurity and predict for child nodes
+ assert(topNode.leftNode.get.predict.predict === 0.0)
+ assert(topNode.rightNode.get.predict.predict === 1.0)
+ assert(topNode.leftNode.get.impurity === 0.0)
+ assert(topNode.rightNode.get.impurity === 0.0)
+ }
+
+ test("Avoid aggregation if impurity is 0.0") {
+ val arr = new Array[LabeledPoint](4)
+ arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0))
+ arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))
+ arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0))
+ arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))
+ val input = sc.parallelize(arr)
+
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+ numClassesForClassification = 2, categoricalFeaturesInfo = Map(0 -> 3))
+ val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+
+ val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+ val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+
+ val topNode = Node.emptyNode(nodeIndex = 1)
+ assert(topNode.predict.predict === Double.MinValue)
+ assert(topNode.impurity === -1.0)
+ assert(topNode.isLeaf === false)
+
+ val nodesForGroup = Map((0, Array(topNode)))
+ val treeToNodeToIndexInfo = Map((0, Map(
+ (topNode.id, new RandomForest.NodeIndexInfo(0, None))
+ )))
+ val nodeQueue = new mutable.Queue[(Int, Node)]()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
+ nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
+
+ // don't enqueue a node into node queue if its impurity is 0.0
+ assert(nodeQueue.isEmpty)
+
+ // set impurity and predict for topNode
+ assert(topNode.predict.predict !== Double.MinValue)
+ assert(topNode.impurity !== -1.0)
+
+ // set impurity and predict for child nodes
+ assert(topNode.leftNode.get.predict.predict === 0.0)
+ assert(topNode.rightNode.get.predict.predict === 1.0)
+ assert(topNode.leftNode.get.impurity === 0.0)
+ assert(topNode.rightNode.get.impurity === 0.0)
+ }
}
object DecisionTreeSuite {