aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorqiping.lqp <qiping.lqp@alibaba-inc.com>2014-09-10 15:37:10 -0700
committerXiangrui Meng <meng@databricks.com>2014-09-10 15:37:10 -0700
commit79cdb9b64ad2fa3ab7f2c221766d36658b917c40 (patch)
tree5469de1d1eb62366177e41202c69a8198c08f377
parent558962a83fb0758ab5c13ff4ea58cc96c29cbbcc (diff)
downloadspark-79cdb9b64ad2fa3ab7f2c221766d36658b917c40.tar.gz
spark-79cdb9b64ad2fa3ab7f2c221766d36658b917c40.tar.bz2
spark-79cdb9b64ad2fa3ab7f2c221766d36658b917c40.zip
[SPARK-2207][SPARK-3272][MLLib]Add minimum information gain and minimum instances per node as training parameters for decision tree.
These two parameters can act as early stop rules to do pre-pruning. When a split cause cause left or right child to have less than `minInstancesPerNode` or has less information gain than `minInfoGain`, current node will not be split by this split. When there is no possible splits that satisfy requirements, there is no useful information gain stats, but we still need to calculate the predict value for current node. So I separated calculation of predict from calculation of information gain, which can also save computation when the number of possible splits is large. Please see [SPARK-3272](https://issues.apache.org/jira/browse/SPARK-3272) for more details. CC: mengxr manishamde jkbradley, please help me review this, thanks. Author: qiping.lqp <qiping.lqp@alibaba-inc.com> Author: chouqin <liqiping1991@gmail.com> Closes #2332 from chouqin/dt-preprune and squashes the following commits: f1d11d1 [chouqin] fix typo c7ebaf1 [chouqin] fix typo 39f9b60 [chouqin] change edge `minInstancesPerNode` to 2 and add one more test 0278a11 [chouqin] remove `noSplit` and set `Predict` private to tree d593ec7 [chouqin] fix docs and change minInstancesPerNode to 1 efcc736 [qiping.lqp] fix bug 10b8012 [qiping.lqp] fix style 6728fad [qiping.lqp] minor fix: remove empty lines bb465ca [qiping.lqp] Merge branch 'master' of https://github.com/apache/spark into dt-preprune cadd569 [qiping.lqp] add api docs 46b891f [qiping.lqp] fix bug e72c7e4 [qiping.lqp] add comments 845c6fa [qiping.lqp] fix style f195e83 [qiping.lqp] fix style 987cbf4 [qiping.lqp] fix bug ff34845 [qiping.lqp] separate calculation of predict of node from calculation of info gain ac42378 [qiping.lqp] add min info gain and min instances per node parameters in decision tree
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala72
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala9
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala7
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala20
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala36
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala2
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala103
7 files changed, 213 insertions, 36 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index d1309b2b20..98596569b8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -130,7 +130,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
// Find best split for all nodes at a level.
timer.start("findBestSplits")
- val splitsStatsForLevel: Array[(Split, InformationGainStats)] =
+ val splitsStatsForLevel: Array[(Split, InformationGainStats, Predict)] =
DecisionTree.findBestSplits(treeInput, parentImpurities,
metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
timer.stop("findBestSplits")
@@ -143,8 +143,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
timer.start("extractNodeInfo")
val split = nodeSplitStats._1
val stats = nodeSplitStats._2
+ val predict = nodeSplitStats._3.predict
val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
- val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
+ val node = new Node(nodeIndex, predict, isLeaf, Some(split), None, None, Some(stats))
logDebug("Node = " + node)
nodes(nodeIndex) = node
timer.stop("extractNodeInfo")
@@ -425,7 +426,7 @@ object DecisionTree extends Serializable with Logging {
splits: Array[Array[Split]],
bins: Array[Array[Bin]],
maxLevelForSingleGroup: Int,
- timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = {
+ timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats, Predict)] = {
// split into groups to avoid memory overflow during aggregation
if (level > maxLevelForSingleGroup) {
// When information for all nodes at a given level cannot be stored in memory,
@@ -434,7 +435,7 @@ object DecisionTree extends Serializable with Logging {
// numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
val numGroups = 1 << level - maxLevelForSingleGroup
logDebug("numGroups = " + numGroups)
- var bestSplits = new Array[(Split, InformationGainStats)](0)
+ var bestSplits = new Array[(Split, InformationGainStats, Predict)](0)
// Iterate over each group of nodes at a level.
var groupIndex = 0
while (groupIndex < numGroups) {
@@ -605,7 +606,7 @@ object DecisionTree extends Serializable with Logging {
bins: Array[Array[Bin]],
timer: TimeTracker,
numGroups: Int = 1,
- groupIndex: Int = 0): Array[(Split, InformationGainStats)] = {
+ groupIndex: Int = 0): Array[(Split, InformationGainStats, Predict)] = {
/*
* The high-level descriptions of the best split optimizations are noted here.
@@ -705,7 +706,7 @@ object DecisionTree extends Serializable with Logging {
// Calculate best splits for all nodes at a given level
timer.start("chooseSplits")
- val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
+ val bestSplits = new Array[(Split, InformationGainStats, Predict)](numNodes)
// Iterating over all nodes at this level
var nodeIndex = 0
while (nodeIndex < numNodes) {
@@ -734,28 +735,27 @@ object DecisionTree extends Serializable with Logging {
topImpurity: Double,
level: Int,
metadata: DecisionTreeMetadata): InformationGainStats = {
-
val leftCount = leftImpurityCalculator.count
val rightCount = rightImpurityCalculator.count
- val totalCount = leftCount + rightCount
- if (totalCount == 0) {
- // Return arbitrary prediction.
- return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
+ // If left child or right child doesn't satisfy minimum instances per node,
+ // then this split is invalid, return invalid information gain stats.
+ if ((leftCount < metadata.minInstancesPerNode) ||
+ (rightCount < metadata.minInstancesPerNode)) {
+ return InformationGainStats.invalidInformationGainStats
}
- val parentNodeAgg = leftImpurityCalculator.copy
- parentNodeAgg.add(rightImpurityCalculator)
+ val totalCount = leftCount + rightCount
+
// impurity of parent node
val impurity = if (level > 0) {
topImpurity
} else {
+ val parentNodeAgg = leftImpurityCalculator.copy
+ parentNodeAgg.add(rightImpurityCalculator)
parentNodeAgg.calculate()
}
- val predict = parentNodeAgg.predict
- val prob = parentNodeAgg.prob(predict)
-
val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
val rightImpurity = rightImpurityCalculator.calculate()
@@ -764,7 +764,31 @@ object DecisionTree extends Serializable with Logging {
val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
- new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
+ // if information gain doesn't satisfy minimum information gain,
+ // then this split is invalid, return invalid information gain stats.
+ if (gain < metadata.minInfoGain) {
+ return InformationGainStats.invalidInformationGainStats
+ }
+
+ new InformationGainStats(gain, impurity, leftImpurity, rightImpurity)
+ }
+
+ /**
+ * Calculate predict value for current node, given stats of any split.
+ * Note that this function is called only once for each node.
+ * @param leftImpurityCalculator left node aggregates for a split
+ * @param rightImpurityCalculator right node aggregates for a node
+ * @return predict value for current node
+ */
+ private def calculatePredict(
+ leftImpurityCalculator: ImpurityCalculator,
+ rightImpurityCalculator: ImpurityCalculator): Predict = {
+ val parentNodeAgg = leftImpurityCalculator.copy
+ parentNodeAgg.add(rightImpurityCalculator)
+ val predict = parentNodeAgg.predict
+ val prob = parentNodeAgg.prob(predict)
+
+ new Predict(predict, prob)
}
/**
@@ -780,12 +804,15 @@ object DecisionTree extends Serializable with Logging {
nodeImpurity: Double,
level: Int,
metadata: DecisionTreeMetadata,
- splits: Array[Array[Split]]): (Split, InformationGainStats) = {
+ splits: Array[Array[Split]]): (Split, InformationGainStats, Predict) = {
logDebug("node impurity = " + nodeImpurity)
+ // calculate predict only once
+ var predict: Option[Predict] = None
+
// For each (feature, split), calculate the gain, and select the best (feature, split).
- Range(0, metadata.numFeatures).map { featureIndex =>
+ val (bestSplit, bestSplitStats) = Range(0, metadata.numFeatures).map { featureIndex =>
val numSplits = metadata.numSplits(featureIndex)
if (metadata.isContinuous(featureIndex)) {
// Cumulative sum (scanLeft) of bin statistics.
@@ -803,6 +830,7 @@ object DecisionTree extends Serializable with Logging {
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
rightChildStats.subtract(leftChildStats)
+ predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
val gainStats =
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
(splitIdx, gainStats)
@@ -816,6 +844,7 @@ object DecisionTree extends Serializable with Logging {
Range(0, numSplits).map { splitIndex =>
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
+ predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
val gainStats =
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
(splitIndex, gainStats)
@@ -887,6 +916,7 @@ object DecisionTree extends Serializable with Logging {
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
rightChildStats.subtract(leftChildStats)
+ predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
val gainStats =
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
(splitIndex, gainStats)
@@ -898,6 +928,10 @@ object DecisionTree extends Serializable with Logging {
(bestFeatureSplit, bestFeatureGainStats)
}
}.maxBy(_._2.gain)
+
+ require(predict.isDefined, "must calculate predict for each node")
+
+ (bestSplit, bestSplitStats, predict.get)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 23f74d5360..987fe632c9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -49,6 +49,13 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* k) implies the feature n is categorical with k categories 0,
* 1, 2, ... , k-1. It's important to note that features are
* zero-indexed.
+ * @param minInstancesPerNode Minimum number of instances each child must have after split.
+ * Default value is 1. If a split cause left or right child
+ * to have less than minInstancesPerNode,
+ * this split will not be considered as a valid split.
+ * @param minInfoGain Minimum information gain a split must get. Default value is 0.0.
+ * If a split has less information gain than minInfoGain,
+ * this split will not be considered as a valid split.
* @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
* 256 MB.
*/
@@ -61,6 +68,8 @@ class Strategy (
val maxBins: Int = 32,
val quantileCalculationStrategy: QuantileStrategy = Sort,
val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
+ val minInstancesPerNode: Int = 1,
+ val minInfoGain: Double = 0.0,
val maxMemoryInMB: Int = 256) extends Serializable {
if (algo == Classification) {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index e95add7558..5ceaa8154d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -45,7 +45,9 @@ private[tree] class DecisionTreeMetadata(
val unorderedFeatures: Set[Int],
val numBins: Array[Int],
val impurity: Impurity,
- val quantileStrategy: QuantileStrategy) extends Serializable {
+ val quantileStrategy: QuantileStrategy,
+ val minInstancesPerNode: Int,
+ val minInfoGain: Double) extends Serializable {
def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)
@@ -127,7 +129,8 @@ private[tree] object DecisionTreeMetadata {
new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
- strategy.impurity, strategy.quantileCalculationStrategy)
+ strategy.impurity, strategy.quantileCalculationStrategy,
+ strategy.minInstancesPerNode, strategy.minInfoGain)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index fb12298e0f..f3e2619bd8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -26,20 +26,26 @@ import org.apache.spark.annotation.DeveloperApi
* @param impurity current node impurity
* @param leftImpurity left node impurity
* @param rightImpurity right node impurity
- * @param predict predicted value
- * @param prob probability of the label (classification only)
*/
@DeveloperApi
class InformationGainStats(
val gain: Double,
val impurity: Double,
val leftImpurity: Double,
- val rightImpurity: Double,
- val predict: Double,
- val prob: Double = 0.0) extends Serializable {
+ val rightImpurity: Double) extends Serializable {
override def toString = {
- "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f, prob = %f"
- .format(gain, impurity, leftImpurity, rightImpurity, predict, prob)
+ "gain = %f, impurity = %f, left impurity = %f, right impurity = %f"
+ .format(gain, impurity, leftImpurity, rightImpurity)
}
}
+
+
+private[tree] object InformationGainStats {
+ /**
+ * An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to
+ * denote that current split doesn't satisfies minimum info gain or
+ * minimum number of instances per node.
+ */
+ val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
new file mode 100644
index 0000000000..6fac2be279
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * :: DeveloperApi ::
+ * Predicted value for a node
+ * @param predict predicted value
+ * @param prob probability of the label (classification only)
+ */
+@DeveloperApi
+private[tree] class Predict(
+ val predict: Double,
+ val prob: Double = 0.0) extends Serializable{
+
+ override def toString = {
+ "predict = %f, prob = %f".format(predict, prob)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
index 50fb48b40d..b7a85f5854 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
@@ -19,6 +19,8 @@ package org.apache.spark.mllib.tree.model
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
+import org.apache.spark.mllib.tree.configuration.FeatureType
+import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
/**
* :: DeveloperApi ::
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 69482f2acb..fd8547c166 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint}
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
-import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node}
+import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node}
import org.apache.spark.mllib.util.LocalSparkContext
class DecisionTreeSuite extends FunSuite with LocalSparkContext {
@@ -279,9 +279,10 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(split.threshold === Double.MinValue)
val stats = bestSplits(0)._2
+ val predict = bestSplits(0)._3
assert(stats.gain > 0)
- assert(stats.predict === 1)
- assert(stats.prob === 0.6)
+ assert(predict.predict === 1)
+ assert(predict.prob === 0.6)
assert(stats.impurity > 0.2)
}
@@ -312,8 +313,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(split.threshold === Double.MinValue)
val stats = bestSplits(0)._2
+ val predict = bestSplits(0)._3.predict
assert(stats.gain > 0)
- assert(stats.predict === 0.6)
+ assert(predict === 0.6)
assert(stats.impurity > 0.2)
}
@@ -387,7 +389,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(0)._2.gain === 0)
assert(bestSplits(0)._2.leftImpurity === 0)
assert(bestSplits(0)._2.rightImpurity === 0)
- assert(bestSplits(0)._2.predict === 1)
+ assert(bestSplits(0)._3.predict === 1)
}
test("Binary classification stump with fixed label 0 for Entropy") {
@@ -414,7 +416,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(0)._2.gain === 0)
assert(bestSplits(0)._2.leftImpurity === 0)
assert(bestSplits(0)._2.rightImpurity === 0)
- assert(bestSplits(0)._2.predict === 0)
+ assert(bestSplits(0)._3.predict === 0)
}
test("Binary classification stump with fixed label 1 for Entropy") {
@@ -441,7 +443,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(0)._2.gain === 0)
assert(bestSplits(0)._2.leftImpurity === 0)
assert(bestSplits(0)._2.rightImpurity === 0)
- assert(bestSplits(0)._2.predict === 1)
+ assert(bestSplits(0)._3.predict === 1)
}
test("Second level node building with vs. without groups") {
@@ -490,7 +492,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(i)._2.impurity === bestSplitsWithGroups(i)._2.impurity)
assert(bestSplits(i)._2.leftImpurity === bestSplitsWithGroups(i)._2.leftImpurity)
assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity)
- assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict)
+ assert(bestSplits(i)._3.predict === bestSplitsWithGroups(i)._3.predict)
}
}
@@ -674,6 +676,91 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
validateClassifier(model, arr, 0.6)
}
+ test("split must satisfy min instances per node requirements") {
+ val arr = new Array[LabeledPoint](3)
+ arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
+ arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
+ arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))
+
+ val input = sc.parallelize(arr)
+ val strategy = new Strategy(algo = Classification, impurity = Gini,
+ maxDepth = 2, numClassesForClassification = 2, minInstancesPerNode = 2)
+
+ val model = DecisionTree.train(input, strategy)
+ assert(model.topNode.isLeaf)
+ assert(model.topNode.predict == 0.0)
+ val predicts = input.map(p => model.predict(p.features)).collect()
+ predicts.foreach { predict =>
+ assert(predict == 0.0)
+ }
+
+ // test for findBestSplits when no valid split can be found
+ val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+ val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
+ new Array[Node](0), splits, bins, 10)
+
+ assert(bestSplits.length == 1)
+ val bestInfoStats = bestSplits(0)._2
+ assert(bestInfoStats == InformationGainStats.invalidInformationGainStats)
+ }
+
+ test("don't choose split that doesn't satisfy min instance per node requirements") {
+ // if a split doesn't satisfy min instances per node requirements,
+ // this split is invalid, even though the information gain of split is large.
+ val arr = new Array[LabeledPoint](4)
+ arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0))
+ arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0, 1.0))
+ arr(2) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0))
+ arr(3) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0))
+
+ val input = sc.parallelize(arr)
+ val strategy = new Strategy(algo = Classification, impurity = Gini,
+ maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2),
+ numClassesForClassification = 2, minInstancesPerNode = 2)
+ val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+ val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
+ new Array[Node](0), splits, bins, 10)
+
+ assert(bestSplits.length == 1)
+ val bestSplit = bestSplits(0)._1
+ val bestSplitStats = bestSplits(0)._1
+ assert(bestSplit.feature == 1)
+ assert(bestSplitStats != InformationGainStats.invalidInformationGainStats)
+ }
+
+ test("split must satisfy min info gain requirements") {
+ val arr = new Array[LabeledPoint](3)
+ arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
+ arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
+ arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))
+
+ val input = sc.parallelize(arr)
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+ numClassesForClassification = 2, minInfoGain = 1.0)
+
+ val model = DecisionTree.train(input, strategy)
+ assert(model.topNode.isLeaf)
+ assert(model.topNode.predict == 0.0)
+ val predicts = input.map(p => model.predict(p.features)).collect()
+ predicts.foreach { predict =>
+ assert(predict == 0.0)
+ }
+
+ // test for findBestSplits when no valid split can be found
+ val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+ val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
+ new Array[Node](0), splits, bins, 10)
+
+ assert(bestSplits.length == 1)
+ val bestInfoStats = bestSplits(0)._2
+ assert(bestInfoStats == InformationGainStats.invalidInformationGainStats)
+ }
}
object DecisionTreeSuite {