-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala  72
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala  9
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala  7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala  20
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala  36
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala  2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala  103
7 files changed, 213 insertions, 36 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index d1309b2b20..98596569b8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -130,7 +130,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
// Find best split for all nodes at a level.
timer.start("findBestSplits")
- val splitsStatsForLevel: Array[(Split, InformationGainStats)] =
+ val splitsStatsForLevel: Array[(Split, InformationGainStats, Predict)] =
DecisionTree.findBestSplits(treeInput, parentImpurities,
metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
timer.stop("findBestSplits")
@@ -143,8 +143,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
timer.start("extractNodeInfo")
val split = nodeSplitStats._1
val stats = nodeSplitStats._2
+ val predict = nodeSplitStats._3.predict
val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
- val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
+ val node = new Node(nodeIndex, predict, isLeaf, Some(split), None, None, Some(stats))
logDebug("Node = " + node)
nodes(nodeIndex) = node
timer.stop("extractNodeInfo")
@@ -425,7 +426,7 @@ object DecisionTree extends Serializable with Logging {
splits: Array[Array[Split]],
bins: Array[Array[Bin]],
maxLevelForSingleGroup: Int,
- timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = {
+ timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats, Predict)] = {
// split into groups to avoid memory overflow during aggregation
if (level > maxLevelForSingleGroup) {
// When information for all nodes at a given level cannot be stored in memory,
@@ -434,7 +435,7 @@ object DecisionTree extends Serializable with Logging {
// numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
val numGroups = 1 << level - maxLevelForSingleGroup
logDebug("numGroups = " + numGroups)
- var bestSplits = new Array[(Split, InformationGainStats)](0)
+ var bestSplits = new Array[(Split, InformationGainStats, Predict)](0)
// Iterate over each group of nodes at a level.
var groupIndex = 0
while (groupIndex < numGroups) {
@@ -605,7 +606,7 @@ object DecisionTree extends Serializable with Logging {
bins: Array[Array[Bin]],
timer: TimeTracker,
numGroups: Int = 1,
- groupIndex: Int = 0): Array[(Split, InformationGainStats)] = {
+ groupIndex: Int = 0): Array[(Split, InformationGainStats, Predict)] = {
/*
* The high-level descriptions of the best split optimizations are noted here.
@@ -705,7 +706,7 @@ object DecisionTree extends Serializable with Logging {
// Calculate best splits for all nodes at a given level
timer.start("chooseSplits")
- val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
+ val bestSplits = new Array[(Split, InformationGainStats, Predict)](numNodes)
// Iterating over all nodes at this level
var nodeIndex = 0
while (nodeIndex < numNodes) {
@@ -734,28 +735,27 @@ object DecisionTree extends Serializable with Logging {
topImpurity: Double,
level: Int,
metadata: DecisionTreeMetadata): InformationGainStats = {
-
val leftCount = leftImpurityCalculator.count
val rightCount = rightImpurityCalculator.count
- val totalCount = leftCount + rightCount
- if (totalCount == 0) {
- // Return arbitrary prediction.
- return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
+ // If the left or right child does not satisfy the minimum number of instances per node,
+ // then this split is invalid; return invalid information gain stats.
+ if ((leftCount < metadata.minInstancesPerNode) ||
+ (rightCount < metadata.minInstancesPerNode)) {
+ return InformationGainStats.invalidInformationGainStats
}
- val parentNodeAgg = leftImpurityCalculator.copy
- parentNodeAgg.add(rightImpurityCalculator)
+ val totalCount = leftCount + rightCount
+
// impurity of parent node
val impurity = if (level > 0) {
topImpurity
} else {
+ val parentNodeAgg = leftImpurityCalculator.copy
+ parentNodeAgg.add(rightImpurityCalculator)
parentNodeAgg.calculate()
}
- val predict = parentNodeAgg.predict
- val prob = parentNodeAgg.prob(predict)
-
val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
val rightImpurity = rightImpurityCalculator.calculate()
@@ -764,7 +764,31 @@ object DecisionTree extends Serializable with Logging {
val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
- new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
+ // If the split does not satisfy the minimum information gain,
+ // then it is invalid; return invalid information gain stats.
+ if (gain < metadata.minInfoGain) {
+ return InformationGainStats.invalidInformationGainStats
+ }
+
+ new InformationGainStats(gain, impurity, leftImpurity, rightImpurity)
+ }
+
+ /**
+ * Calculate predict value for current node, given stats of any split.
+ * Note that this function is called only once for each node.
+ * @param leftImpurityCalculator left node aggregates for a split
+ * @param rightImpurityCalculator right node aggregates for a split
+ * @return predict value for current node
+ */
+ private def calculatePredict(
+ leftImpurityCalculator: ImpurityCalculator,
+ rightImpurityCalculator: ImpurityCalculator): Predict = {
+ val parentNodeAgg = leftImpurityCalculator.copy
+ parentNodeAgg.add(rightImpurityCalculator)
+ val predict = parentNodeAgg.predict
+ val prob = parentNodeAgg.prob(predict)
+
+ new Predict(predict, prob)
}
/**
@@ -780,12 +804,15 @@ object DecisionTree extends Serializable with Logging {
nodeImpurity: Double,
level: Int,
metadata: DecisionTreeMetadata,
- splits: Array[Array[Split]]): (Split, InformationGainStats) = {
+ splits: Array[Array[Split]]): (Split, InformationGainStats, Predict) = {
logDebug("node impurity = " + nodeImpurity)
+ // calculate the predict value only once per node
+ var predict: Option[Predict] = None
+
// For each (feature, split), calculate the gain, and select the best (feature, split).
- Range(0, metadata.numFeatures).map { featureIndex =>
+ val (bestSplit, bestSplitStats) = Range(0, metadata.numFeatures).map { featureIndex =>
val numSplits = metadata.numSplits(featureIndex)
if (metadata.isContinuous(featureIndex)) {
// Cumulative sum (scanLeft) of bin statistics.
@@ -803,6 +830,7 @@ object DecisionTree extends Serializable with Logging {
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
rightChildStats.subtract(leftChildStats)
+ predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
val gainStats =
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
(splitIdx, gainStats)
@@ -816,6 +844,7 @@ object DecisionTree extends Serializable with Logging {
Range(0, numSplits).map { splitIndex =>
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
+ predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
val gainStats =
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
(splitIndex, gainStats)
@@ -887,6 +916,7 @@ object DecisionTree extends Serializable with Logging {
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
rightChildStats.subtract(leftChildStats)
+ predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
val gainStats =
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
(splitIndex, gainStats)
@@ -898,6 +928,10 @@ object DecisionTree extends Serializable with Logging {
(bestFeatureSplit, bestFeatureGainStats)
}
}.maxBy(_._2.gain)
+
+ require(predict.isDefined, "must calculate predict for each node")
+
+ (bestSplit, bestSplitStats, predict.get)
}
/**
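Taken together, these DecisionTree.scala changes make findBestSplits return a (Split, InformationGainStats, Predict) triple per node, compute the Predict value once per node, and collapse every candidate split that violates minInstancesPerNode or minInfoGain to the shared invalid sentinel. A condensed sketch of the resulting flow inside binsToBestSplit, assuming a hypothetical candidateSplits collection that stands in for the per-feature iteration over (split, left aggregates, right aggregates) and using the helpers introduced above:

    // Predict is computed lazily from the first aggregate pair seen for this node.
    var predict: Option[Predict] = None
    val (bestSplit, bestSplitStats) = candidateSplits.map { case (split, left, right) =>
      predict = Some(predict.getOrElse(calculatePredict(left, right)))
      (split, calculateGainForSplit(left, right, nodeImpurity, level, metadata))
    }.maxBy(_._2.gain)
    // A rejected candidate carries gain = Double.MinValue, so it is selected only when
    // every candidate is invalid; the caller then turns the node into a leaf.
    (bestSplit, bestSplitStats, predict.get)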
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 23f74d5360..987fe632c9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -49,6 +49,13 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* k) implies the feature n is categorical with k categories 0,
* 1, 2, ... , k-1. It's important to note that features are
* zero-indexed.
+ * @param minInstancesPerNode Minimum number of instances each child must have after a split.
+ * Default value is 1. If a split causes the left or right child
+ * to have fewer than minInstancesPerNode instances,
+ * the split will not be considered valid.
+ * @param minInfoGain Minimum information gain a split must achieve. Default value is 0.0.
+ * If a split yields less information gain than minInfoGain,
+ * the split will not be considered valid.
* @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
* 256 MB.
*/
@@ -61,6 +68,8 @@ class Strategy (
val maxBins: Int = 32,
val quantileCalculationStrategy: QuantileStrategy = Sort,
val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
+ val minInstancesPerNode: Int = 1,
+ val minInfoGain: Double = 0.0,
val maxMemoryInMB: Int = 256) extends Serializable {
if (algo == Classification) {
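For reference, a minimal usage sketch of the two new training parameters. This is not part of the patch: trainingData is a placeholder RDD[LabeledPoint], and the concrete values (maxDepth = 4, minInstancesPerNode = 5, minInfoGain = 0.01) are arbitrary; the constructor arguments follow the signature above and the DecisionTree.train call used in the test suite below.

    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.tree.configuration.Algo.Classification
    import org.apache.spark.mllib.tree.configuration.Strategy
    import org.apache.spark.mllib.tree.impurity.Gini
    import org.apache.spark.mllib.tree.model.DecisionTreeModel
    import org.apache.spark.rdd.RDD

    // A split is kept only if both children retain at least minInstancesPerNode examples
    // and the split yields at least minInfoGain information gain; otherwise the node
    // becomes a leaf.
    def trainWithConstraints(trainingData: RDD[LabeledPoint]): DecisionTreeModel = {
      val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
        numClassesForClassification = 2, minInstancesPerNode = 5, minInfoGain = 0.01)
      DecisionTree.train(trainingData, strategy)
    }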
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index e95add7558..5ceaa8154d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -45,7 +45,9 @@ private[tree] class DecisionTreeMetadata(
val unorderedFeatures: Set[Int],
val numBins: Array[Int],
val impurity: Impurity,
- val quantileStrategy: QuantileStrategy) extends Serializable {
+ val quantileStrategy: QuantileStrategy,
+ val minInstancesPerNode: Int,
+ val minInfoGain: Double) extends Serializable {
def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)
@@ -127,7 +129,8 @@ private[tree] object DecisionTreeMetadata {
new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
- strategy.impurity, strategy.quantileCalculationStrategy)
+ strategy.impurity, strategy.quantileCalculationStrategy,
+ strategy.minInstancesPerNode, strategy.minInfoGain)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index fb12298e0f..f3e2619bd8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -26,20 +26,26 @@ import org.apache.spark.annotation.DeveloperApi
* @param impurity current node impurity
* @param leftImpurity left node impurity
* @param rightImpurity right node impurity
- * @param predict predicted value
- * @param prob probability of the label (classification only)
*/
@DeveloperApi
class InformationGainStats(
val gain: Double,
val impurity: Double,
val leftImpurity: Double,
- val rightImpurity: Double,
- val predict: Double,
- val prob: Double = 0.0) extends Serializable {
+ val rightImpurity: Double) extends Serializable {
override def toString = {
- "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f, prob = %f"
- .format(gain, impurity, leftImpurity, rightImpurity, predict, prob)
+ "gain = %f, impurity = %f, left impurity = %f, right impurity = %f"
+ .format(gain, impurity, leftImpurity, rightImpurity)
}
}
+
+
+private[tree] object InformationGainStats {
+ /**
+ * An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to
+ * denote that the current split does not satisfy the minimum info gain or
+ * the minimum number of instances per node.
+ */
+ val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0)
+}
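A note on how this sentinel is consumed: calculateGainForSplit returns this exact shared instance for every rejected candidate, and InformationGainStats does not override equals, so comparing against invalidInformationGainStats (as the test suite below does with ==) amounts to a reference-identity check for "no valid split exists". A small sketch, where bestSplits is the array returned by findBestSplits:

    val stats = bestSplits(0)._2
    // True only when every candidate split for this node violated
    // minInstancesPerNode or minInfoGain.
    val noValidSplit = stats eq InformationGainStats.invalidInformationGainStats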
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
new file mode 100644
index 0000000000..6fac2be279
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * :: DeveloperApi ::
+ * Predicted value for a node
+ * @param predict predicted value
+ * @param prob probability of the label (classification only)
+ */
+@DeveloperApi
+private[tree] class Predict(
+ val predict: Double,
+ val prob: Double = 0.0) extends Serializable {
+
+ override def toString = {
+ "predict = %f, prob = %f".format(predict, prob)
+ }
+}
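For intuition: predict is the value the node would emit (the majority label for classification, the mean for regression), and prob is the probability of that label, which the doc comment above notes is meaningful for classification only. A hypothetical illustration, not taken from the patch, consistent with the assertions in the test suite below:

    // A classification node where 60% of the training instances carry label 1.0
    // is summarized as:
    val nodePredict = new Predict(predict = 1.0, prob = 0.6)
    // which matches the test assertions predict.predict === 1 and predict.prob === 0.6 below.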
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
index 50fb48b40d..b7a85f5854 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
@@ -19,6 +19,8 @@ package org.apache.spark.mllib.tree.model
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
+import org.apache.spark.mllib.tree.configuration.FeatureType
+import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
/**
* :: DeveloperApi ::
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 69482f2acb..fd8547c166 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint}
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
-import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node}
+import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node}
import org.apache.spark.mllib.util.LocalSparkContext
class DecisionTreeSuite extends FunSuite with LocalSparkContext {
@@ -279,9 +279,10 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(split.threshold === Double.MinValue)
val stats = bestSplits(0)._2
+ val predict = bestSplits(0)._3
assert(stats.gain > 0)
- assert(stats.predict === 1)
- assert(stats.prob === 0.6)
+ assert(predict.predict === 1)
+ assert(predict.prob === 0.6)
assert(stats.impurity > 0.2)
}
@@ -312,8 +313,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(split.threshold === Double.MinValue)
val stats = bestSplits(0)._2
+ val predict = bestSplits(0)._3.predict
assert(stats.gain > 0)
- assert(stats.predict === 0.6)
+ assert(predict === 0.6)
assert(stats.impurity > 0.2)
}
@@ -387,7 +389,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(0)._2.gain === 0)
assert(bestSplits(0)._2.leftImpurity === 0)
assert(bestSplits(0)._2.rightImpurity === 0)
- assert(bestSplits(0)._2.predict === 1)
+ assert(bestSplits(0)._3.predict === 1)
}
test("Binary classification stump with fixed label 0 for Entropy") {
@@ -414,7 +416,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(0)._2.gain === 0)
assert(bestSplits(0)._2.leftImpurity === 0)
assert(bestSplits(0)._2.rightImpurity === 0)
- assert(bestSplits(0)._2.predict === 0)
+ assert(bestSplits(0)._3.predict === 0)
}
test("Binary classification stump with fixed label 1 for Entropy") {
@@ -441,7 +443,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(0)._2.gain === 0)
assert(bestSplits(0)._2.leftImpurity === 0)
assert(bestSplits(0)._2.rightImpurity === 0)
- assert(bestSplits(0)._2.predict === 1)
+ assert(bestSplits(0)._3.predict === 1)
}
test("Second level node building with vs. without groups") {
@@ -490,7 +492,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(i)._2.impurity === bestSplitsWithGroups(i)._2.impurity)
assert(bestSplits(i)._2.leftImpurity === bestSplitsWithGroups(i)._2.leftImpurity)
assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity)
- assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict)
+ assert(bestSplits(i)._3.predict === bestSplitsWithGroups(i)._3.predict)
}
}
@@ -674,6 +676,91 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
validateClassifier(model, arr, 0.6)
}
+ test("split must satisfy min instances per node requirements") {
+ val arr = new Array[LabeledPoint](3)
+ arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
+ arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
+ arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))
+
+ val input = sc.parallelize(arr)
+ val strategy = new Strategy(algo = Classification, impurity = Gini,
+ maxDepth = 2, numClassesForClassification = 2, minInstancesPerNode = 2)
+
+ val model = DecisionTree.train(input, strategy)
+ assert(model.topNode.isLeaf)
+ assert(model.topNode.predict == 0.0)
+ val predicts = input.map(p => model.predict(p.features)).collect()
+ predicts.foreach { predict =>
+ assert(predict == 0.0)
+ }
+
+ // test for findBestSplits when no valid split can be found
+ val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+ val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
+ new Array[Node](0), splits, bins, 10)
+
+ assert(bestSplits.length == 1)
+ val bestInfoStats = bestSplits(0)._2
+ assert(bestInfoStats == InformationGainStats.invalidInformationGainStats)
+ }
+
+ test("don't choose split that doesn't satisfy min instance per node requirements") {
+ // if a split doesn't satisfy min instances per node requirements,
+ // this split is invalid, even though the information gain of split is large.
+ val arr = new Array[LabeledPoint](4)
+ arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0))
+ arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0, 1.0))
+ arr(2) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0))
+ arr(3) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0))
+
+ val input = sc.parallelize(arr)
+ val strategy = new Strategy(algo = Classification, impurity = Gini,
+ maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2),
+ numClassesForClassification = 2, minInstancesPerNode = 2)
+ val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+ val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
+ new Array[Node](0), splits, bins, 10)
+
+ assert(bestSplits.length == 1)
+ val bestSplit = bestSplits(0)._1
+ val bestSplitStats = bestSplits(0)._2
+ assert(bestSplit.feature == 1)
+ assert(bestSplitStats != InformationGainStats.invalidInformationGainStats)
+ }
+
+ test("split must satisfy min info gain requirements") {
+ val arr = new Array[LabeledPoint](3)
+ arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
+ arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
+ arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))
+
+ val input = sc.parallelize(arr)
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+ numClassesForClassification = 2, minInfoGain = 1.0)
+
+ val model = DecisionTree.train(input, strategy)
+ assert(model.topNode.isLeaf)
+ assert(model.topNode.predict == 0.0)
+ val predicts = input.map(p => model.predict(p.features)).collect()
+ predicts.foreach { predict =>
+ assert(predict == 0.0)
+ }
+
+ // test for findBestSplits when no valid split can be found
+ val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+ val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
+ new Array[Node](0), splits, bins, 10)
+
+ assert(bestSplits.length == 1)
+ val bestInfoStats = bestSplits(0)._2
+ assert(bestInfoStats == InformationGainStats.invalidInformationGainStats)
+ }
}
object DecisionTreeSuite {