aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala84
1 files changed, 73 insertions, 11 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index be383aab71..35e92d71dc 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -22,7 +22,8 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model.Filter
-import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.model.Split
+import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.linalg.Vectors
@@ -242,7 +243,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
- Array[List[Filter]](), splits, bins)
+ Array[List[Filter]](), splits, bins, 10)
val split = bestSplits(0)._1
assert(split.categories.length === 1)
@@ -269,7 +270,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
- Array[List[Filter]](), splits, bins)
+ Array[List[Filter]](), splits, bins, 10)
val split = bestSplits(0)._1
assert(split.categories.length === 1)
@@ -298,7 +299,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins(0).length === 100)
val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
- Array[List[Filter]](), splits, bins)
+ Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._1.threshold === 10)
@@ -321,7 +322,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins(0).length === 100)
val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
- Array[List[Filter]](), splits, bins)
+ Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._1.threshold === 10)
@@ -345,7 +346,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins(0).length === 100)
val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
- Array[List[Filter]](), splits, bins)
+ Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._1.threshold === 10)
@@ -369,7 +370,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins(0).length === 100)
val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
- Array[List[Filter]](), splits, bins)
+ Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._1.threshold === 10)
@@ -378,13 +379,60 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(0)._2.rightImpurity === 0)
assert(bestSplits(0)._2.predict === 1)
}
+
+ test("test second level node building with/without groups") {
+ val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(Classification, Entropy, 3, 100)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ assert(splits.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins.length === 2)
+ assert(bins(0).length === 100)
+ assert(splits(0).length === 99)
+ assert(bins(0).length === 100)
+
+ val leftFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()), -1)
+ val rightFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()) ,1)
+ val filters = Array[List[Filter]](List(), List(leftFilter), List(rightFilter))
+ val parentImpurities = Array(0.5, 0.5, 0.5)
+
+ // Single group second level tree construction.
+ val bestSplits = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, filters,
+ splits, bins, 10)
+ assert(bestSplits.length === 2)
+ assert(bestSplits(0)._2.gain > 0)
+ assert(bestSplits(1)._2.gain > 0)
+
+ // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second
+ // level tree construction.
+ val bestSplitsWithGroups = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1,
+ filters, splits, bins, 0)
+ assert(bestSplitsWithGroups.length === 2)
+ assert(bestSplitsWithGroups(0)._2.gain > 0)
+ assert(bestSplitsWithGroups(1)._2.gain > 0)
+
+ // Verify whether the splits obtained using single group and multiple group level
+ // construction strategies are the same.
+ for (i <- 0 until bestSplits.length) {
+ assert(bestSplits(i)._1 === bestSplitsWithGroups(i)._1)
+ assert(bestSplits(i)._2.gain === bestSplitsWithGroups(i)._2.gain)
+ assert(bestSplits(i)._2.impurity === bestSplitsWithGroups(i)._2.impurity)
+ assert(bestSplits(i)._2.leftImpurity === bestSplitsWithGroups(i)._2.leftImpurity)
+ assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity)
+ assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict)
+ }
+
+ }
+
}
object DecisionTreeSuite {
def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](1000)
- for (i <- 0 until 1000){
+ for (i <- 0 until 1000) {
val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
arr(i) = lp
}
@@ -393,17 +441,31 @@ object DecisionTreeSuite {
def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](1000)
- for (i <- 0 until 1000){
+ for (i <- 0 until 1000) {
val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i))
arr(i) = lp
}
arr
}
+ def generateOrderedLabeledPoints(): Array[LabeledPoint] = {
+ val arr = new Array[LabeledPoint](1000)
+ for (i <- 0 until 1000) {
+ if (i < 600) {
+ val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
+ arr(i) = lp
+ } else {
+ val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 1000.0 - i))
+ arr(i) = lp
+ }
+ }
+ arr
+ }
+
def generateCategoricalDataPoints(): Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](1000)
- for (i <- 0 until 1000){
- if (i < 600){
+ for (i <- 0 until 1000) {
+ if (i < 600) {
arr(i) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0))
} else {
arr(i) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0))