aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph.kurata.bradley@gmail.com>2014-08-16 23:53:14 -0700
committerXiangrui Meng <meng@databricks.com>2014-08-16 23:53:14 -0700
commit73ab7f141c205df277c6ac19252e590d6806c41f (patch)
tree004324c406ca843bd58b8ff73f5a3cb6ab2c7437 /mllib/src/test
parentfbad72288d8b6e641b00417a544cae6e8bfef2d7 (diff)
downloadspark-73ab7f141c205df277c6ac19252e590d6806c41f.tar.gz
spark-73ab7f141c205df277c6ac19252e590d6806c41f.tar.bz2
spark-73ab7f141c205df277c6ac19252e590d6806c41f.zip
[SPARK-3042] [mllib] DecisionTree Filter top-down instead of bottom-up
DecisionTree needs to match each example to a node at each iteration. It currently does this with a set of filters very inefficiently: For each example, it examines each node at the current level and traces up to the root to see if that example should be handled by that node. Fix: Filter top-down using the partly built tree itself. Major changes: * Eliminated Filter class, findBinsForLevel() method. * Set up node parent links in main loop over levels in train(). * Added predictNodeIndex() for filtering top-down. * Added DTMetadata class Other changes: * Pre-compute set of unorderedFeatures. Notes for following expected PR based on [https://issues.apache.org/jira/browse/SPARK-3043]: * The unorderedFeatures set will next be stored in a metadata structure to simplify function calls (to store other items such as the data in strategy). I've done initial tests indicating that this speeds things up, but am only now running large-scale ones. CC: mengxr manishamde chouqin Any comments are welcome---thanks! Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com> Closes #1975 from jkbradley/dt-opt2 and squashes the following commits: a0ed0da [Joseph K. Bradley] Renamed DTMetadata to DecisionTreeMetadata. Small doc updates. 3726d20 [Joseph K. Bradley] Small code improvements based on code review. ac0b9f8 [Joseph K. Bradley] Small updates based on code review. Main change: Now using << instead of math.pow. db0d773 [Joseph K. Bradley] scala style fix 6a38f48 [Joseph K. Bradley] Added DTMetadata class for cleaner code 931a3a7 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt2 797f68a [Joseph K. Bradley] Fixed DecisionTreeSuite bug for training second level. Needed to update treePointToNodeIndex with groupShift. f40381c [Joseph K. Bradley] Merge branch 'dt-opt1' into dt-opt2 5f2dec2 [Joseph K. Bradley] Fixed scalastyle issue in TreePoint 6b5651e [Joseph K. Bradley] Updates based on code review. 
1 major change: persisting to memory + disk, not just memory. 2d2aaaf [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1 26d10dd [Joseph K. Bradley] Removed tree/model/Filter.scala since no longer used. Removed debugging println calls in DecisionTree.scala. 356daba [Joseph K. Bradley] Merge branch 'dt-opt1' into dt-opt2 430d782 [Joseph K. Bradley] Added more debug info on binning error. Added some docs. d036089 [Joseph K. Bradley] Print timing info to logDebug. e66f1b1 [Joseph K. Bradley] TreePoint * Updated doc * Made some methods private 8464a6e [Joseph K. Bradley] Moved TimeTracker to tree/impl/ in its own file, and cleaned it up. Removed debugging println calls from DecisionTree. Made TreePoint extend Serializable a87e08f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1 c1565a5 [Joseph K. Bradley] Small DecisionTree updates: * Simplification: Updated calculateGainForSplit to take aggregates for a single (feature, split) pair. * Internal doc: findAggForOrderedFeatureClassification b914f3b [Joseph K. Bradley] DecisionTree optimization: eliminated filters + small changes b2ed1f3 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt 0f676e2 [Joseph K. Bradley] Optimizations + Bug fix for DecisionTree 3211f02 [Joseph K. Bradley] Optimizing DecisionTree * Added TreePoint representation to avoid calling findBin multiple times. * (not working yet, but debugging) f61e9d2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing bcf874a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing 511ec85 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing a95bc22 [Joseph K. Bradley] timing for DecisionTree internals
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala167
1 file changed, 95 insertions, 72 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index a5c49a38dc..2f36fd9077 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -23,10 +23,10 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
-import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy}
-import org.apache.spark.mllib.tree.impl.TreePoint
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint}
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
-import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split}
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.regression.LabeledPoint
@@ -64,7 +64,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Gini, 3, 2, 100)
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(bins.length === 2)
assert(splits(0).length === 99)
@@ -82,7 +83,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(bins.length === 2)
assert(splits(0).length === 99)
@@ -162,7 +164,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
// Check splits.
@@ -279,7 +282,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 100,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
// Expecting 2^2 - 1 = 3 bins/splits
assert(splits(0)(0).feature === 0)
@@ -373,7 +377,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 100,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 10, 1-> 10))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
// 2^10 - 1 > 100, so categorical variables will be ordered
@@ -428,10 +433,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
maxDepth = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
- val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0,
- Array[List[Filter]](), splits, bins, 10)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+ val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0,
+ new Array[Node](0), splits, bins, 10)
val split = bestSplits(0)._1
assert(split.categories.length === 1)
@@ -456,10 +462,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
maxDepth = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
- val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0,
- Array[List[Filter]](), splits, bins, 10)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+ val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0,
+ new Array[Node](0), splits, bins, 10)
val split = bestSplits(0)._1
assert(split.categories.length === 1)
@@ -495,7 +502,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Gini, 3, 2, 100)
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(splits(0).length === 99)
assert(bins.length === 2)
@@ -503,9 +511,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(splits(0).length === 99)
assert(bins(0).length === 100)
- val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0,
- Array[List[Filter]](), splits, bins, 10)
+ val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0,
+ new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._2.gain === 0)
@@ -518,7 +526,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Gini, 3, 2, 100)
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(splits(0).length === 99)
assert(bins.length === 2)
@@ -526,9 +535,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(splits(0).length === 99)
assert(bins(0).length === 100)
- val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
- val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0,
- Array[List[Filter]](), splits, bins, 10)
+ val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0,
+ new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._2.gain === 0)
@@ -542,7 +551,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(splits(0).length === 99)
assert(bins.length === 2)
@@ -550,9 +560,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(splits(0).length === 99)
assert(bins(0).length === 100)
- val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
- val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0,
- Array[List[Filter]](), splits, bins, 10)
+ val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0,
+ new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._2.gain === 0)
@@ -566,7 +576,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(splits(0).length === 99)
assert(bins.length === 2)
@@ -574,9 +585,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(splits(0).length === 99)
assert(bins(0).length === 100)
- val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
- val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0,
- Array[List[Filter]](), splits, bins, 10)
+ val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0,
+ new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._2.gain === 0)
@@ -590,7 +601,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(splits(0).length === 99)
assert(bins.length === 2)
@@ -598,14 +610,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(splits(0).length === 99)
assert(bins(0).length === 100)
- val leftFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()), -1)
- val rightFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()) ,1)
- val filters = Array[List[Filter]](List(), List(leftFilter), List(rightFilter))
+ // Train a 1-node model
+ val strategyOneNode = new Strategy(Classification, Entropy, 1, 2, 100)
+ val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
+ val nodes: Array[Node] = new Array[Node](7)
+ nodes(0) = modelOneNode.topNode
+ nodes(0).leftNode = None
+ nodes(0).rightNode = None
+
val parentImpurities = Array(0.5, 0.5, 0.5)
// Single group second level tree construction.
- val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
- val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1, filters,
+ val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1, nodes,
splits, bins, 10)
assert(bestSplits.length === 2)
assert(bestSplits(0)._2.gain > 0)
@@ -613,8 +630,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
// maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second
// level tree construction.
- val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1,
- filters, splits, bins, 0)
+ val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1,
+ nodes, splits, bins, 0)
assert(bestSplitsWithGroups.length === 2)
assert(bestSplitsWithGroups(0)._2.gain > 0)
assert(bestSplitsWithGroups(1)._2.gain > 0)
@@ -629,19 +646,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity)
assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict)
}
-
}
test("stump with categorical variables for multiclass classification") {
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
- val input = sc.parallelize(arr)
+ val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
assert(strategy.isMulticlassClassification)
- val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
- val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
- Array[List[Filter]](), splits, bins, 10)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+ val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+ new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
val bestSplit = bestSplits(0)._1
@@ -657,11 +674,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0))
arr(2) = new LabeledPoint(1.0, Vectors.dense(2.0))
arr(3) = new LabeledPoint(1.0, Vectors.dense(3.0))
- val input = sc.parallelize(arr)
+ val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClassesForClassification = 2)
- val model = DecisionTree.train(input, strategy)
+ val model = DecisionTree.train(rdd, strategy)
validateClassifier(model, arr, 1.0)
assert(model.numNodes === 3)
assert(model.depth === 1)
@@ -688,20 +705,22 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
test("stump with categorical variables for multiclass classification, with just enough bins") {
val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
- val input = sc.parallelize(arr)
+ val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
- numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
+ numClassesForClassification = 3, maxBins = maxBins,
+ categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
assert(strategy.isMulticlassClassification)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
- val model = DecisionTree.train(input, strategy)
+ val model = DecisionTree.train(rdd, strategy)
validateClassifier(model, arr, 1.0)
assert(model.numNodes === 3)
assert(model.depth === 1)
- val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
- val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
- Array[List[Filter]](), splits, bins, 10)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+ val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+ new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
val bestSplit = bestSplits(0)._1
@@ -716,18 +735,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
test("stump with continuous variables for multiclass classification") {
val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
- val input = sc.parallelize(arr)
+ val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClassesForClassification = 3)
assert(strategy.isMulticlassClassification)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
- val model = DecisionTree.train(input, strategy)
+ val model = DecisionTree.train(rdd, strategy)
validateClassifier(model, arr, 0.9)
- val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
- val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
- Array[List[Filter]](), splits, bins, 10)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+ val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+ new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
val bestSplit = bestSplits(0)._1
@@ -741,18 +761,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
test("stump with continuous + categorical variables for multiclass classification") {
val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
- val input = sc.parallelize(arr)
+ val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3))
assert(strategy.isMulticlassClassification)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
- val model = DecisionTree.train(input, strategy)
+ val model = DecisionTree.train(rdd, strategy)
validateClassifier(model, arr, 0.9)
- val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
- val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
- Array[List[Filter]](), splits, bins, 10)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+ val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+ new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
val bestSplit = bestSplits(0)._1
@@ -765,14 +786,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
test("stump with categorical variables for ordered multiclass classification") {
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
- val input = sc.parallelize(arr)
+ val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
assert(strategy.isMulticlassClassification)
- val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
- val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
- Array[List[Filter]](), splits, bins, 10)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+ val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+ new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
val bestSplit = bestSplits(0)._1