diff options
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala | 84 |
1 files changed, 73 insertions, 11 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index be383aab71..35e92d71dc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -22,7 +22,8 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model.Filter -import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.model.Split +import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.linalg.Vectors @@ -242,7 +243,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) val split = bestSplits(0)._1 assert(split.categories.length === 1) @@ -269,7 +270,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) val split = bestSplits(0)._1 assert(split.categories.length === 1) @@ -298,7 +299,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._1.threshold === 10) @@ -321,7 +322,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._1.threshold === 10) @@ -345,7 +346,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._1.threshold === 10) @@ -369,7 +370,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._1.threshold === 10) @@ -378,13 +379,60 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(0)._2.rightImpurity === 0) assert(bestSplits(0)._2.predict === 1) } + + test("test second level node building with/without groups") { + val arr = DecisionTreeSuite.generateOrderedLabeledPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Entropy, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + + val leftFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()), -1) + val rightFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()) ,1) + val filters = Array[List[Filter]](List(), List(leftFilter), List(rightFilter)) + val parentImpurities = Array(0.5, 0.5, 0.5) + + // Single group second level tree construction. + val bestSplits = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, filters, + splits, bins, 10) + assert(bestSplits.length === 2) + assert(bestSplits(0)._2.gain > 0) + assert(bestSplits(1)._2.gain > 0) + + // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second + // level tree construction. + val bestSplitsWithGroups = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, + filters, splits, bins, 0) + assert(bestSplitsWithGroups.length === 2) + assert(bestSplitsWithGroups(0)._2.gain > 0) + assert(bestSplitsWithGroups(1)._2.gain > 0) + + // Verify whether the splits obtained using single group and multiple group level + // construction strategies are the same. + for (i <- 0 until bestSplits.length) { + assert(bestSplits(i)._1 === bestSplitsWithGroups(i)._1) + assert(bestSplits(i)._2.gain === bestSplitsWithGroups(i)._2.gain) + assert(bestSplits(i)._2.impurity === bestSplitsWithGroups(i)._2.impurity) + assert(bestSplits(i)._2.leftImpurity === bestSplitsWithGroups(i)._2.leftImpurity) + assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity) + assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict) + } + + } + } object DecisionTreeSuite { def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) - for (i <- 0 until 1000){ + for (i <- 0 until 1000) { val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) arr(i) = lp } @@ -393,17 +441,31 @@ object DecisionTreeSuite { def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) - for (i <- 0 until 1000){ + for (i <- 0 until 1000) { val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i)) arr(i) = lp } arr } + def generateOrderedLabeledPoints(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](1000) + for (i <- 0 until 1000) { + if (i < 600) { + val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) + arr(i) = lp + } else { + val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 1000.0 - i)) + arr(i) = lp + } + } + arr + } + def generateCategoricalDataPoints(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) - for (i <- 0 until 1000){ - if (i < 600){ + for (i <- 0 until 1000) { + if (i < 600) { arr(i) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)) } else { arr(i) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0)) |