author     Joseph K. Bradley <joseph@databricks.com>  2016-03-23 21:16:00 -0700
committer  Joseph K. Bradley <joseph@databricks.com>  2016-03-23 21:16:00 -0700
commit     cf823bead18c5be86b36da59b4bbf935c4804d04 (patch)
tree       7e48dd6b225e4f2ce670d6c5513215c914053194 /mllib/src/test
parent     f42eaf42bdca8bc6f390f1f31ee60faa1662489b (diff)
[SPARK-12183][ML][MLLIB] Remove mllib tree implementation, and wrap spark.ml one
Primary change: * Removed spark.mllib.tree.DecisionTree implementation of tree and forest learning. * spark.mllib now calls the spark.ml implementation. * Moved unit tests (of tree learning internals) from spark.mllib to spark.ml as needed. ml.tree.DecisionTreeModel * Added toOld and made ```private[spark]```, implemented for Classifier and Regressor in subclasses. These methods now use OldInformationGainStats.invalidInformationGainStats for LeafNodes in order to mimic the spark.mllib implementation. ml.tree.Node * Added ```private[tree] def deepCopy```, used by unit tests Copied developer comments from spark.mllib implementation to spark.ml one. Moving unit tests * Tree learning internals were tested by spark.mllib.tree.DecisionTreeSuite, or spark.mllib.tree.RandomForestSuite. * Those tests were all moved to spark.ml.tree.impl.RandomForestSuite. The order in the file + the test names are the same, so you should be able to compare them by opening them in 2 windows side-by-side. * I made minimal changes to each test to allow it to run. Each test makes the same checks as before, except for a few removed assertions which were checking irrelevant values. * No new unit tests were added. * mllib.tree.DecisionTreeSuite: I removed some checks of splits and bins which were not relevant to the unit tests they were in. Those same split calculations were already being tested in other unit tests, for each dataset type. **Changes of behavior** (to be noted in SPARK-13448 once this PR is merged) * spark.ml.tree.impl.RandomForest: Rather than throwing an error when maxMemoryInMB is set to too small a value (to split any node), we now allow 1 node to be split, even if its memory requirements exceed maxMemoryInMB. This involved removing the maxMemoryPerNode check in RandomForest.run, as well as modifying selectNodesToSplit(). Once this PR is merged, I will note the change of behavior on SPARK-13448. * spark.mllib.tree.DecisionTree: When a tree only has one node (root = leaf node), the "stats" field will now be empty, rather than being set to InformationGainStats.invalidInformationGainStats. This does not remove information from the tree, and it will save a bit of storage. Author: Joseph K. Bradley <joseph@databricks.com> Closes #11855 from jkbradley/remove-mllib-tree-impl.
Diffstat (limited to 'mllib/src/test')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala | 418
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala   | 486
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala   |  83
3 files changed, 435 insertions(+), 552 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index 9d922291a6..361366fde7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -17,11 +17,17 @@
package org.apache.spark.ml.tree.impl
+import scala.collection.mutable
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
-import org.apache.spark.ml.tree.{ContinuousSplit, DecisionTreeModel, LeafNode, Node}
+import org.apache.spark.ml.tree._
import org.apache.spark.mllib.linalg.{Vector, Vectors}
-import org.apache.spark.mllib.tree.impurity.GiniCalculator
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata}
+import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.collection.OpenHashMap
@@ -33,6 +39,414 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
import RandomForestSuite.mapToVec
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests for split calculation
+ /////////////////////////////////////////////////////////////////////////////
+
+ test("Binary classification with continuous features: split calculation") {
+ val arr = OldDTSuite.generateOrderedLabeledPointsWithLabel1()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2, 100)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ val splits = RandomForest.findSplits(rdd, metadata, seed = 42)
+ assert(splits.length === 2)
+ assert(splits(0).length === 99)
+ }
+
+ test("Binary classification with binary (ordered) categorical features: split calculation") {
+ val arr = OldDTSuite.generateCategoricalDataPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2,
+ maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2))
+
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val splits = RandomForest.findSplits(rdd, metadata, seed = 42)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ assert(!metadata.isUnordered(featureIndex = 1))
+ assert(splits.length === 2)
+ // no splits pre-computed for ordered categorical features
+ assert(splits(0).length === 0)
+ }
+
+ test("Binary classification with 3-ary (ordered) categorical features," +
+ " with no samples for one category: split calculation") {
+ val arr = OldDTSuite.generateCategoricalDataPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2,
+ maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
+
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ assert(!metadata.isUnordered(featureIndex = 1))
+ val splits = RandomForest.findSplits(rdd, metadata, seed = 42)
+ assert(splits.length === 2)
+ // no splits pre-computed for ordered categorical features
+ assert(splits(0).length === 0)
+ }
+
+ test("find splits for a continuous feature") {
+ // find splits for normal case
+ {
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ Map(), Set(),
+ Array(6), Gini, QuantileStrategy.Sort,
+ 0, 0, 0.0, 0, 0
+ )
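+ // numBins is Array(6) for the lone feature; for a continuous feature, numSplits = numBins - 1 = 5.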
+ val featureSamples = Array.fill(200000)(math.random)
+ val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+ assert(splits.length === 5)
+ assert(fakeMetadata.numSplits(0) === 5)
+ assert(fakeMetadata.numBins(0) === 6)
+ // check returned splits are distinct
+ assert(splits.distinct.length === splits.length)
+ }
+
+ // findSplitsForContinuousFeature should not return identical splits;
+ // when there are not enough distinct candidates, it reduces the number of splits in metadata
+ {
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ Map(), Set(),
+ Array(5), Gini, QuantileStrategy.Sort,
+ 0, 0, 0.0, 0, 0
+ )
+ val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
+ val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+ assert(splits.length === 3)
+ // check returned splits are distinct
+ assert(splits.distinct.length === splits.length)
+ }
+
+ // find splits when most samples are close to the minimum
+ {
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ Map(), Set(),
+ Array(3), Gini, QuantileStrategy.Sort,
+ 0, 0, 0.0, 0, 0
+ )
+ val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
+ val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+ assert(splits.length === 2)
+ assert(splits(0) === 2.0)
+ assert(splits(1) === 3.0)
+ }
+
+ // find splits when most samples are close to the maximum
+ {
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ Map(), Set(),
+ Array(3), Gini, QuantileStrategy.Sort,
+ 0, 0, 0.0, 0, 0
+ )
+ val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
+ val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+ assert(splits.length === 1)
+ assert(splits(0) === 1.0)
+ }
+ }
+
+ test("Multiclass classification with unordered categorical features: split calculations") {
+ val arr = OldDTSuite.generateCategoricalDataPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new OldStrategy(
+ OldAlgo.Classification,
+ Gini,
+ maxDepth = 2,
+ numClasses = 100,
+ maxBins = 100,
+ categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
+
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(metadata.isUnordered(featureIndex = 0))
+ assert(metadata.isUnordered(featureIndex = 1))
+ val splits = RandomForest.findSplits(rdd, metadata, seed = 42)
+ assert(splits.length === 2)
+ assert(splits(0).length === 3)
+ assert(metadata.numSplits(0) === 3)
+ assert(metadata.numBins(0) === 3)
+ assert(metadata.numSplits(1) === 3)
+ assert(metadata.numBins(1) === 3)
+
+ // With 3 categories per unordered feature, expect 2^(3-1) - 1 = 3 splits per feature
+ def checkCategoricalSplit(s: Split, featureIndex: Int, leftCategories: Array[Double]): Unit = {
+ assert(s.featureIndex === featureIndex)
+ assert(s.isInstanceOf[CategoricalSplit])
+ val s0 = s.asInstanceOf[CategoricalSplit]
+ assert(s0.leftCategories === leftCategories)
+ assert(s0.numCategories === 3) // for this unit test
+ }
+ // Feature 0
+ checkCategoricalSplit(splits(0)(0), 0, Array(0.0))
+ checkCategoricalSplit(splits(0)(1), 0, Array(1.0))
+ checkCategoricalSplit(splits(0)(2), 0, Array(0.0, 1.0))
+ // Feature 1
+ checkCategoricalSplit(splits(1)(0), 1, Array(0.0))
+ checkCategoricalSplit(splits(1)(1), 1, Array(1.0))
+ checkCategoricalSplit(splits(1)(2), 1, Array(0.0, 1.0))
+ }
+
+ test("Multiclass classification with ordered categorical features: split calculations") {
+ val arr = OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
+ assert(arr.length === 3000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 100,
+ maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
+ // 2^(10-1) - 1 > 100, so categorical features will be ordered
+
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ assert(!metadata.isUnordered(featureIndex = 1))
+ val splits = RandomForest.findSplits(rdd, metadata, seed = 42)
+ assert(splits.length === 2)
+ // no splits pre-computed for ordered categorical features
+ assert(splits(0).length === 0)
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests of other algorithm internals
+ /////////////////////////////////////////////////////////////////////////////
+
+ test("extract categories from a number for multiclass classification") {
+ val l = RandomForest.extractMultiClassCategories(13, 10)
+ assert(l.length === 3)
+ assert(Seq(3.0, 2.0, 0.0) === l)
+ }
+
+ test("Avoid aggregation on the last level") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
+ val input = sc.parallelize(arr)
+
+ val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1,
+ numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
+ val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+ val splits = RandomForest.findSplits(input, metadata, seed = 42)
+
+ val treeInput = TreePoint.convertToTreeRDD(input, splits, metadata)
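+ // subsamplingRate = 1.0, one subsample, no replacement: effectively no bagging, each point appears once.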
+ val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, withReplacement = false)
+
+ val topNode = LearningNode.emptyNode(nodeIndex = 1)
+ assert(topNode.isLeaf === false)
+ assert(topNode.stats === null)
+
+ val nodesForGroup = Map((0, Array(topNode)))
+ val treeToNodeToIndexInfo = Map((0, Map(
+ (topNode.id, new RandomForest.NodeIndexInfo(0, None))
+ )))
+ val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
+ RandomForest.findBestSplits(baggedInput, metadata, Array(topNode),
+ nodesForGroup, treeToNodeToIndexInfo, splits, nodeQueue)
+
+ // don't enqueue leaf nodes into node queue
+ assert(nodeQueue.isEmpty)
+
+ // set impurity and predict for topNode
+ assert(topNode.stats !== null)
+ assert(topNode.stats.impurity > 0.0)
+
+ // set impurity and predict for child nodes
+ assert(topNode.leftChild.get.toNode.prediction === 0.0)
+ assert(topNode.rightChild.get.toNode.prediction === 1.0)
+ assert(topNode.leftChild.get.stats.impurity === 0.0)
+ assert(topNode.rightChild.get.stats.impurity === 0.0)
+ }
+
+ test("Avoid aggregation if impurity is 0.0") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
+ val input = sc.parallelize(arr)
+
+ val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 5,
+ numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
+ val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+ val splits = RandomForest.findSplits(input, metadata, seed = 42)
+
+ val treeInput = TreePoint.convertToTreeRDD(input, splits, metadata)
+ val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, withReplacement = false)
+
+ val topNode = LearningNode.emptyNode(nodeIndex = 1)
+ assert(topNode.isLeaf === false)
+ assert(topNode.stats === null)
+
+ val nodesForGroup = Map((0, Array(topNode)))
+ val treeToNodeToIndexInfo = Map((0, Map(
+ (topNode.id, new RandomForest.NodeIndexInfo(0, None))
+ )))
+ val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
+ RandomForest.findBestSplits(baggedInput, metadata, Array(topNode),
+ nodesForGroup, treeToNodeToIndexInfo, splits, nodeQueue)
+
+ // don't enqueue a node into node queue if its impurity is 0.0
+ assert(nodeQueue.isEmpty)
+
+ // set impurity and predict for topNode
+ assert(topNode.stats !== null)
+ assert(topNode.stats.impurity > 0.0)
+
+ // set impurity and predict for child nodes
+ assert(topNode.leftChild.get.toNode.prediction === 0.0)
+ assert(topNode.rightChild.get.toNode.prediction === 1.0)
+ assert(topNode.leftChild.get.stats.impurity === 0.0)
+ assert(topNode.rightChild.get.stats.impurity === 0.0)
+ }
+
+ test("Use soft prediction for binary classification with ordered categorical features") {
+ // The following dataset is set up such that the best split is {1} vs. {0, 2}.
+ // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen.
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0)),
+ LabeledPoint(1.0, Vectors.dense(2.0)))
+ val input = sc.parallelize(arr)
+
+ // Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
+ val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1,
+ numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
+
+ val model = RandomForest.run(input, strategy, numTrees = 1, featureSubsetStrategy = "all",
+ seed = 42).head
+ model.rootNode match {
+ case n: InternalNode => n.split match {
+ case s: CategoricalSplit =>
+ assert(s.leftCategories === Array(1.0))
+ }
+ }
+ }
+
+ test("Second level node building with vs. without groups") {
+ val arr = OldDTSuite.generateOrderedLabeledPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ // For tree with 1 group
+ val strategy1 =
+ new OldStrategy(OldAlgo.Classification, Entropy, 3, 2, 100, maxMemoryInMB = 1000)
+ // For tree with multiple groups
+ val strategy2 =
+ new OldStrategy(OldAlgo.Classification, Entropy, 3, 2, 100, maxMemoryInMB = 0)
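+ // Per the behavior change in this commit, maxMemoryInMB = 0 still lets one node per group be split, so the second-level nodes are built one group at a time.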
+
+ val tree1 = RandomForest.run(rdd, strategy1, numTrees = 1, featureSubsetStrategy = "all",
+ seed = 42).head
+ val tree2 = RandomForest.run(rdd, strategy2, numTrees = 1, featureSubsetStrategy = "all",
+ seed = 42).head
+
+ def getChildren(rootNode: Node): Array[InternalNode] = rootNode match {
+ case n: InternalNode =>
+ assert(n.leftChild.isInstanceOf[InternalNode])
+ assert(n.rightChild.isInstanceOf[InternalNode])
+ Array(n.leftChild.asInstanceOf[InternalNode], n.rightChild.asInstanceOf[InternalNode])
+ }
+
+ // Extract the second-level children from the trees built with one group and with multiple groups.
+ val children1 = getChildren(tree1.rootNode)
+ val children2 = getChildren(tree2.rootNode)
+
+ // Verify whether the splits obtained using single group and multiple group level
+ // construction strategies are the same.
+ for (i <- 0 until 2) {
+ assert(children1(i).gain > 0)
+ assert(children2(i).gain > 0)
+ assert(children1(i).split === children2(i).split)
+ assert(children1(i).impurity === children2(i).impurity)
+ assert(children1(i).impurityStats.stats === children2(i).impurityStats.stats)
+ assert(children1(i).leftChild.impurity === children2(i).leftChild.impurity)
+ assert(children1(i).rightChild.impurity === children2(i).rightChild.impurity)
+ assert(children1(i).prediction === children2(i).prediction)
+ }
+ }
+
+ def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: OldStrategy) {
+ val numFeatures = 50
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000)
+ val rdd = sc.parallelize(arr)
+
+ // Select feature subsets for the top nodes and check the chosen subset sizes (asserts on failure).
+ def checkFeatureSubsetStrategy(
+ numTrees: Int,
+ featureSubsetStrategy: String,
+ numFeaturesPerNode: Int): Unit = {
+ val seeds = Array(123, 5354, 230, 349867, 23987)
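+ // 128 MB cap passed to selectNodesToSplit below; large enough here that memory does not limit the group size.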
+ val maxMemoryUsage: Long = 128 * 1024L * 1024L
+ val metadata =
+ DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees, featureSubsetStrategy)
+ seeds.foreach { seed =>
+ val failString = s"Failed on test with: " +
+ s"numTrees=$numTrees, featureSubsetStrategy=$featureSubsetStrategy," +
+ s" numFeaturesPerNode=$numFeaturesPerNode, seed=$seed"
+ val nodeQueue = new mutable.Queue[(Int, LearningNode)]()
+ val topNodes: Array[LearningNode] = new Array[LearningNode](numTrees)
+ Range(0, numTrees).foreach { treeIndex =>
+ topNodes(treeIndex) = LearningNode.emptyNode(nodeIndex = 1)
+ nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))
+ }
+ val rng = new scala.util.Random(seed = seed)
+ val (nodesForGroup: Map[Int, Array[LearningNode]],
+ treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) =
+ RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
+
+ assert(nodesForGroup.size === numTrees, failString)
+ assert(nodesForGroup.values.forall(_.length == 1), failString) // 1 node per tree
+
+ if (numFeaturesPerNode == numFeatures) {
+ // featureSubset values should all be None
+ assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)),
+ failString)
+ } else {
+ // Check number of features.
+ assert(treeToNodeToIndexInfo.values.forall(_.values.forall(
+ _.featureSubset.get.length === numFeaturesPerNode)), failString)
+ }
+ }
+ }
+
+ checkFeatureSubsetStrategy(numTrees = 1, "auto", numFeatures)
+ checkFeatureSubsetStrategy(numTrees = 1, "all", numFeatures)
+ checkFeatureSubsetStrategy(numTrees = 1, "sqrt", math.sqrt(numFeatures).ceil.toInt)
+ checkFeatureSubsetStrategy(numTrees = 1, "log2",
+ (math.log(numFeatures) / math.log(2)).ceil.toInt)
+ checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt)
+
+ checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures)
+ checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt)
+ checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt)
+ checkFeatureSubsetStrategy(numTrees = 2, "log2",
+ (math.log(numFeatures) / math.log(2)).ceil.toInt)
+ checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt)
+ }
+
+ test("Binary classification with continuous features: subsampling features") {
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 2,
+ numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+ binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
+ }
+
+ test("Binary classification with continuous features and node Id cache: subsampling features") {
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 2,
+ numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+ useNodeIdCache = true)
+ binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
+ }
+
test("computeFeatureImportance, featureImportances") {
/* Build tree for testing, with this structure:
grandParent
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 89b64fce96..bb1041b109 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -18,430 +18,23 @@
package org.apache.spark.mllib.tree
import scala.collection.JavaConverters._
-import scala.collection.mutable
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
-import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint}
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model._
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
/////////////////////////////////////////////////////////////////////////////
- // Tests examining individual elements of training
- /////////////////////////////////////////////////////////////////////////////
-
- test("Binary classification with continuous features: split and bin calculation") {
- val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
- assert(arr.length === 1000)
- val rdd = sc.parallelize(arr)
- val strategy = new Strategy(Classification, Gini, 3, 2, 100)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
- assert(!metadata.isUnordered(featureIndex = 0))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- assert(splits.length === 2)
- assert(bins.length === 2)
- assert(splits(0).length === 99)
- assert(bins(0).length === 100)
- }
-
- test("Binary classification with binary (ordered) categorical features:" +
- " split and bin calculation") {
- val arr = DecisionTreeSuite.generateCategoricalDataPoints()
- assert(arr.length === 1000)
- val rdd = sc.parallelize(arr)
- val strategy = new Strategy(
- Classification,
- Gini,
- maxDepth = 2,
- numClasses = 2,
- maxBins = 100,
- categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2))
-
- val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- assert(!metadata.isUnordered(featureIndex = 0))
- assert(!metadata.isUnordered(featureIndex = 1))
- assert(splits.length === 2)
- assert(bins.length === 2)
- // no bins or splits pre-computed for ordered categorical features
- assert(splits(0).length === 0)
- assert(bins(0).length === 0)
- }
-
- test("Binary classification with 3-ary (ordered) categorical features," +
- " with no samples for one category") {
- val arr = DecisionTreeSuite.generateCategoricalDataPoints()
- assert(arr.length === 1000)
- val rdd = sc.parallelize(arr)
- val strategy = new Strategy(
- Classification,
- Gini,
- maxDepth = 2,
- numClasses = 2,
- maxBins = 100,
- categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
-
- val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
- assert(!metadata.isUnordered(featureIndex = 0))
- assert(!metadata.isUnordered(featureIndex = 1))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- assert(splits.length === 2)
- assert(bins.length === 2)
- // no bins or splits pre-computed for ordered categorical features
- assert(splits(0).length === 0)
- assert(bins(0).length === 0)
- }
-
- test("extract categories from a number for multiclass classification") {
- val l = DecisionTree.extractMultiClassCategories(13, 10)
- assert(l.length === 3)
- assert(List(3.0, 2.0, 0.0).toSeq === l.toSeq)
- }
-
- test("find splits for a continuous feature") {
- // find splits for normal case
- {
- val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
- Map(), Set(),
- Array(6), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
- )
- val featureSamples = Array.fill(200000)(math.random)
- val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
- assert(splits.length === 5)
- assert(fakeMetadata.numSplits(0) === 5)
- assert(fakeMetadata.numBins(0) === 6)
- // check returned splits are distinct
- assert(splits.distinct.length === splits.length)
- }
-
- // find splits should not return identical splits
- // when there are not enough split candidates, reduce the number of splits in metadata
- {
- val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
- Map(), Set(),
- Array(5), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
- )
- val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
- val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
- assert(splits.length === 3)
- // check returned splits are distinct
- assert(splits.distinct.length === splits.length)
- }
-
- // find splits when most samples close to the minimum
- {
- val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
- Map(), Set(),
- Array(3), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
- )
- val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
- val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
- assert(splits.length === 2)
- assert(splits(0) === 2.0)
- assert(splits(1) === 3.0)
- }
-
- // find splits when most samples close to the maximum
- {
- val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
- Map(), Set(),
- Array(3), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
- )
- val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
- val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
- assert(splits.length === 1)
- assert(splits(0) === 1.0)
- }
- }
-
- test("Multiclass classification with unordered categorical features:" +
- " split and bin calculations") {
- val arr = DecisionTreeSuite.generateCategoricalDataPoints()
- assert(arr.length === 1000)
- val rdd = sc.parallelize(arr)
- val strategy = new Strategy(
- Classification,
- Gini,
- maxDepth = 2,
- numClasses = 100,
- maxBins = 100,
- categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
-
- val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
- assert(metadata.isUnordered(featureIndex = 0))
- assert(metadata.isUnordered(featureIndex = 1))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- assert(splits.length === 2)
- assert(bins.length === 2)
- assert(splits(0).length === 3)
- assert(bins(0).length === 0)
- assert(metadata.numSplits(0) === 3)
- assert(metadata.numBins(0) === 3)
- assert(metadata.numSplits(1) === 3)
- assert(metadata.numBins(1) === 3)
-
- // Expecting 2^2 - 1 = 3 bins/splits
- assert(splits(0)(0).feature === 0)
- assert(splits(0)(0).threshold === Double.MinValue)
- assert(splits(0)(0).featureType === Categorical)
- assert(splits(0)(0).categories.length === 1)
- assert(splits(0)(0).categories.contains(0.0))
- assert(splits(1)(0).feature === 1)
- assert(splits(1)(0).threshold === Double.MinValue)
- assert(splits(1)(0).featureType === Categorical)
- assert(splits(1)(0).categories.length === 1)
- assert(splits(1)(0).categories.contains(0.0))
-
- assert(splits(0)(1).feature === 0)
- assert(splits(0)(1).threshold === Double.MinValue)
- assert(splits(0)(1).featureType === Categorical)
- assert(splits(0)(1).categories.length === 1)
- assert(splits(0)(1).categories.contains(1.0))
- assert(splits(1)(1).feature === 1)
- assert(splits(1)(1).threshold === Double.MinValue)
- assert(splits(1)(1).featureType === Categorical)
- assert(splits(1)(1).categories.length === 1)
- assert(splits(1)(1).categories.contains(1.0))
-
- assert(splits(0)(2).feature === 0)
- assert(splits(0)(2).threshold === Double.MinValue)
- assert(splits(0)(2).featureType === Categorical)
- assert(splits(0)(2).categories.length === 2)
- assert(splits(0)(2).categories.contains(0.0))
- assert(splits(0)(2).categories.contains(1.0))
- assert(splits(1)(2).feature === 1)
- assert(splits(1)(2).threshold === Double.MinValue)
- assert(splits(1)(2).featureType === Categorical)
- assert(splits(1)(2).categories.length === 2)
- assert(splits(1)(2).categories.contains(0.0))
- assert(splits(1)(2).categories.contains(1.0))
-
- }
-
- test("Multiclass classification with ordered categorical features: split and bin calculations") {
- val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
- assert(arr.length === 3000)
- val rdd = sc.parallelize(arr)
- val strategy = new Strategy(
- Classification,
- Gini,
- maxDepth = 2,
- numClasses = 100,
- maxBins = 100,
- categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
- // 2^(10-1) - 1 > 100, so categorical features will be ordered
-
- val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
- assert(!metadata.isUnordered(featureIndex = 0))
- assert(!metadata.isUnordered(featureIndex = 1))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- assert(splits.length === 2)
- assert(bins.length === 2)
- // no bins or splits pre-computed for ordered categorical features
- assert(splits(0).length === 0)
- assert(bins(0).length === 0)
- }
-
- test("Avoid aggregation on the last level") {
- val arr = Array(
- LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)),
- LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
- LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
- LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
- val input = sc.parallelize(arr)
-
- val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
- numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
- val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
- val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
-
- val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
- val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
-
- val topNode = Node.emptyNode(nodeIndex = 1)
- assert(topNode.predict.predict === Double.MinValue)
- assert(topNode.impurity === -1.0)
- assert(topNode.isLeaf === false)
-
- val nodesForGroup = Map((0, Array(topNode)))
- val treeToNodeToIndexInfo = Map((0, Map(
- (topNode.id, new RandomForest.NodeIndexInfo(0, None))
- )))
- val nodeQueue = new mutable.Queue[(Int, Node)]()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
- nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
-
- // don't enqueue leaf nodes into node queue
- assert(nodeQueue.isEmpty)
-
- // set impurity and predict for topNode
- assert(topNode.predict.predict !== Double.MinValue)
- assert(topNode.impurity !== -1.0)
-
- // set impurity and predict for child nodes
- assert(topNode.leftNode.get.predict.predict === 0.0)
- assert(topNode.rightNode.get.predict.predict === 1.0)
- assert(topNode.leftNode.get.impurity === 0.0)
- assert(topNode.rightNode.get.impurity === 0.0)
- }
-
- test("Avoid aggregation if impurity is 0.0") {
- val arr = Array(
- LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)),
- LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
- LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
- LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
- val input = sc.parallelize(arr)
-
- val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
- numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
- val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
- val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
-
- val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
- val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
-
- val topNode = Node.emptyNode(nodeIndex = 1)
- assert(topNode.predict.predict === Double.MinValue)
- assert(topNode.impurity === -1.0)
- assert(topNode.isLeaf === false)
-
- val nodesForGroup = Map((0, Array(topNode)))
- val treeToNodeToIndexInfo = Map((0, Map(
- (topNode.id, new RandomForest.NodeIndexInfo(0, None))
- )))
- val nodeQueue = new mutable.Queue[(Int, Node)]()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
- nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
-
- // don't enqueue a node into node queue if its impurity is 0.0
- assert(nodeQueue.isEmpty)
-
- // set impurity and predict for topNode
- assert(topNode.predict.predict !== Double.MinValue)
- assert(topNode.impurity !== -1.0)
-
- // set impurity and predict for child nodes
- assert(topNode.leftNode.get.predict.predict === 0.0)
- assert(topNode.rightNode.get.predict.predict === 1.0)
- assert(topNode.leftNode.get.impurity === 0.0)
- assert(topNode.rightNode.get.impurity === 0.0)
- }
-
- test("Use soft prediction for binary classification with ordered categorical features") {
- // The following dataset is set up such that the best split is {1} vs. {0, 2}.
- // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen.
- val arr = Array(
- LabeledPoint(0.0, Vectors.dense(0.0)),
- LabeledPoint(0.0, Vectors.dense(0.0)),
- LabeledPoint(0.0, Vectors.dense(0.0)),
- LabeledPoint(1.0, Vectors.dense(0.0)),
- LabeledPoint(0.0, Vectors.dense(1.0)),
- LabeledPoint(0.0, Vectors.dense(1.0)),
- LabeledPoint(0.0, Vectors.dense(1.0)),
- LabeledPoint(0.0, Vectors.dense(1.0)),
- LabeledPoint(0.0, Vectors.dense(2.0)),
- LabeledPoint(0.0, Vectors.dense(2.0)),
- LabeledPoint(0.0, Vectors.dense(2.0)),
- LabeledPoint(1.0, Vectors.dense(2.0)))
- val input = sc.parallelize(arr)
-
- // Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
- val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
- numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
-
- val model = new DecisionTree(strategy).run(input)
- model.topNode.split.get match {
- case Split(_, _, _, categories: List[Double]) =>
- assert(categories === List(1.0))
- }
- }
-
- test("Second level node building with vs. without groups") {
- val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
- assert(arr.length === 1000)
- val rdd = sc.parallelize(arr)
- val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- assert(splits.length === 2)
- assert(splits(0).length === 99)
- assert(bins.length === 2)
- assert(bins(0).length === 100)
-
- // Train a 1-node model
- val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1,
- numClasses = 2, maxBins = 100)
- val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
- val rootNode1 = modelOneNode.topNode.deepCopy()
- val rootNode2 = modelOneNode.topNode.deepCopy()
- assert(rootNode1.leftNode.nonEmpty)
- assert(rootNode1.rightNode.nonEmpty)
-
- val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
-
- // Single group second level tree construction.
- val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get)))
- val treeToNodeToIndexInfo = Map((0, Map(
- (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)),
- (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None)))))
- val nodeQueue = new mutable.Queue[(Int, Node)]()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1),
- nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
- val children1 = new Array[Node](2)
- children1(0) = rootNode1.leftNode.get
- children1(1) = rootNode1.rightNode.get
-
- // Train one second-level node at a time.
- val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get)))
- val treeToNodeToIndexInfoA = Map((0, Map(
- (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
- nodeQueue.clear()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
- nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue)
- val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get)))
- val treeToNodeToIndexInfoB = Map((0, Map(
- (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
- nodeQueue.clear()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
- nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue)
- val children2 = new Array[Node](2)
- children2(0) = rootNode2.leftNode.get
- children2(1) = rootNode2.rightNode.get
-
- // Verify whether the splits obtained using single group and multiple group level
- // construction strategies are the same.
- for (i <- 0 until 2) {
- assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0)
- assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0)
- assert(children1(i).split === children2(i).split)
- assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty)
- val stats1 = children1(i).stats.get
- val stats2 = children2(i).stats.get
- assert(stats1.gain === stats2.gain)
- assert(stats1.impurity === stats2.impurity)
- assert(stats1.leftImpurity === stats2.leftImpurity)
- assert(stats1.rightImpurity === stats2.rightImpurity)
- assert(children1(i).predict.predict === children2(i).predict.predict)
- }
- }
-
- /////////////////////////////////////////////////////////////////////////////
// Tests calling train()
/////////////////////////////////////////////////////////////////////////////
@@ -457,22 +50,11 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
- val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
- assert(!metadata.isUnordered(featureIndex = 0))
- assert(!metadata.isUnordered(featureIndex = 1))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- assert(splits.length === 2)
- assert(bins.length === 2)
- // no bins or splits pre-computed for ordered categorical features
- assert(splits(0).length === 0)
- assert(bins(0).length === 0)
-
val rootNode = DecisionTree.train(rdd, strategy).topNode
val split = rootNode.split.get
assert(split.categories === List(1.0))
assert(split.featureType === Categorical)
- assert(split.threshold === Double.MinValue)
val stats = rootNode.stats.get
assert(stats.gain > 0)
@@ -501,7 +83,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(split.categories.length === 1)
assert(split.categories.contains(1.0))
assert(split.featureType === Categorical)
- assert(split.threshold === Double.MinValue)
val stats = rootNode.stats.get
assert(stats.gain > 0)
@@ -539,18 +120,11 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- assert(splits.length === 2)
- assert(splits(0).length === 99)
- assert(bins.length === 2)
- assert(bins(0).length === 100)
-
val rootNode = DecisionTree.train(rdd, strategy).topNode
- val stats = rootNode.stats.get
- assert(stats.gain === 0)
- assert(stats.leftImpurity === 0)
- assert(stats.rightImpurity === 0)
+ assert(rootNode.impurity === 0)
+ assert(rootNode.stats.isEmpty)
+ assert(rootNode.predict.predict === 0)
}
test("Binary classification stump with fixed label 1 for Gini") {
@@ -563,18 +137,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- assert(splits.length === 2)
- assert(splits(0).length === 99)
- assert(bins.length === 2)
- assert(bins(0).length === 100)
-
val rootNode = DecisionTree.train(rdd, strategy).topNode
- val stats = rootNode.stats.get
- assert(stats.gain === 0)
- assert(stats.leftImpurity === 0)
- assert(stats.rightImpurity === 0)
+ assert(rootNode.impurity === 0)
+ assert(rootNode.stats.isEmpty)
assert(rootNode.predict.predict === 1)
}
@@ -588,18 +154,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- assert(splits.length === 2)
- assert(splits(0).length === 99)
- assert(bins.length === 2)
- assert(bins(0).length === 100)
-
val rootNode = DecisionTree.train(rdd, strategy).topNode
- val stats = rootNode.stats.get
- assert(stats.gain === 0)
- assert(stats.leftImpurity === 0)
- assert(stats.rightImpurity === 0)
+ assert(rootNode.impurity === 0)
+ assert(rootNode.stats.isEmpty)
assert(rootNode.predict.predict === 0)
}
@@ -613,18 +171,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- assert(splits.length === 2)
- assert(splits(0).length === 99)
- assert(bins.length === 2)
- assert(bins(0).length === 100)
-
val rootNode = DecisionTree.train(rdd, strategy).topNode
- val stats = rootNode.stats.get
- assert(stats.gain === 0)
- assert(stats.leftImpurity === 0)
- assert(stats.rightImpurity === 0)
+ assert(rootNode.impurity === 0)
+ assert(rootNode.stats.isEmpty)
assert(rootNode.predict.predict === 1)
}
@@ -718,7 +268,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClasses = 3, maxBins = 100)
assert(strategy.isMulticlassClassification)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
val model = DecisionTree.train(rdd, strategy)
DecisionTreeSuite.validateClassifier(model, arr, 0.9)
@@ -807,8 +356,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
// test when no valid split can be found
val rootNode = model.topNode
- val gain = rootNode.stats.get
- assert(gain == InformationGainStats.invalidInformationGainStats)
+ assert(rootNode.stats.isEmpty)
}
test("do not choose split that does not satisfy min instance per node requirements") {
@@ -828,9 +376,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val rootNode = DecisionTree.train(rdd, strategy).topNode
val split = rootNode.split.get
- val gain = rootNode.stats.get
+ val gainStats = rootNode.stats.get
assert(split.feature == 1)
- assert(gain != InformationGainStats.invalidInformationGainStats)
+ assert(gainStats.gain >= 0)
+ assert(gainStats.impurity >= 0)
}
test("split must satisfy min info gain requirements") {
@@ -852,10 +401,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
}
// test when no valid split can be found
- val rootNode = model.topNode
-
- val gain = rootNode.stats.get
- assert(gain == InformationGainStats.invalidInformationGainStats)
+ assert(model.topNode.stats.isEmpty)
}
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index c72fc9bb4f..bec61ba6a0 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -17,16 +17,13 @@
package org.apache.spark.mllib.tree
-import scala.collection.mutable
-
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
import org.apache.spark.mllib.tree.impurity.{Gini, Variance}
-import org.apache.spark.mllib.tree.model.{Node, RandomForestModel}
+import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.util.Utils
@@ -42,7 +39,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = numTrees,
featureSubsetStrategy = "auto", seed = 123)
- assert(rf.trees.size === 1)
+ assert(rf.trees.length === 1)
val rfTree = rf.trees(0)
val dt = DecisionTree.train(rdd, strategy)
@@ -78,7 +75,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees,
featureSubsetStrategy = "auto", seed = 123)
- assert(rf.trees.size === 1)
+ assert(rf.trees.length === 1)
val rfTree = rf.trees(0)
val dt = DecisionTree.train(rdd, strategy)
@@ -108,80 +105,6 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
regressionTestWithContinuousFeatures(strategy)
}
- def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: Strategy) {
- val numFeatures = 50
- val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000)
- val rdd = sc.parallelize(arr)
-
- // Select feature subset for top nodes. Return true if OK.
- def checkFeatureSubsetStrategy(
- numTrees: Int,
- featureSubsetStrategy: String,
- numFeaturesPerNode: Int): Unit = {
- val seeds = Array(123, 5354, 230, 349867, 23987)
- val maxMemoryUsage: Long = 128 * 1024L * 1024L
- val metadata =
- DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees, featureSubsetStrategy)
- seeds.foreach { seed =>
- val failString = s"Failed on test with:" +
- s"numTrees=$numTrees, featureSubsetStrategy=$featureSubsetStrategy," +
- s" numFeaturesPerNode=$numFeaturesPerNode, seed=$seed"
- val nodeQueue = new mutable.Queue[(Int, Node)]()
- val topNodes: Array[Node] = new Array[Node](numTrees)
- Range(0, numTrees).foreach { treeIndex =>
- topNodes(treeIndex) = Node.emptyNode(nodeIndex = 1)
- nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))
- }
- val rng = new scala.util.Random(seed = seed)
- val (nodesForGroup: Map[Int, Array[Node]],
- treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) =
- RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
-
- assert(nodesForGroup.size === numTrees, failString)
- assert(nodesForGroup.values.forall(_.size == 1), failString) // 1 node per tree
-
- if (numFeaturesPerNode == numFeatures) {
- // featureSubset values should all be None
- assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)),
- failString)
- } else {
- // Check number of features.
- assert(treeToNodeToIndexInfo.values.forall(_.values.forall(
- _.featureSubset.get.size === numFeaturesPerNode)), failString)
- }
- }
- }
-
- checkFeatureSubsetStrategy(numTrees = 1, "auto", numFeatures)
- checkFeatureSubsetStrategy(numTrees = 1, "all", numFeatures)
- checkFeatureSubsetStrategy(numTrees = 1, "sqrt", math.sqrt(numFeatures).ceil.toInt)
- checkFeatureSubsetStrategy(numTrees = 1, "log2",
- (math.log(numFeatures) / math.log(2)).ceil.toInt)
- checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt)
-
- checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures)
- checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt)
- checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt)
- checkFeatureSubsetStrategy(numTrees = 2, "log2",
- (math.log(numFeatures) / math.log(2)).ceil.toInt)
- checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt)
- }
-
- test("Binary classification with continuous features: subsampling features") {
- val categoricalFeaturesInfo = Map.empty[Int, Int]
- val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
- numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
- binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
- }
-
- test("Binary classification with continuous features and node Id cache: subsampling features") {
- val categoricalFeaturesInfo = Map.empty[Int, Int]
- val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
- numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
- useNodeIdCache = true)
- binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
- }
-
test("alternating categorical and continuous features with multiclass labels to test indexing") {
val arr = new Array[LabeledPoint](4)
arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0))