aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorManish Amde <manish9ue@gmail.com>2014-04-01 21:40:49 -0700
committerMatei Zaharia <matei@databricks.com>2014-04-01 21:40:49 -0700
commit8b3045ceab591a3f3ca18823c7e2c5faca38a06e (patch)
tree11fa8935ed2793e4c1b3e06b042b459e8a5aee27 /mllib/src/test
parent45df9127365f8942794273b8ada004bf6ea3ef10 (diff)
downloadspark-8b3045ceab591a3f3ca18823c7e2c5faca38a06e.tar.gz
spark-8b3045ceab591a3f3ca18823c7e2c5faca38a06e.tar.bz2
spark-8b3045ceab591a3f3ca18823c7e2c5faca38a06e.zip
MLI-1 Decision Trees
Joint work with @hirakendu, @etrain, @atalwalkar and @harsha2010. Key features: + Supports binary classification and regression + Supports gini, entropy and variance for information gain calculation + Supports both continuous and categorical features The algorithm has gone through several development iterations over the last few months leading to a highly optimized implementation. Optimizations include: 1. Level-wise training to reduce passes over the entire dataset. 2. Bin-wise split calculation to reduce computation overhead. 3. Aggregation over partitions before combining to reduce communication overhead. Author: Manish Amde <manish9ue@gmail.com> Author: manishamde <manish9ue@gmail.com> Author: Xiangrui Meng <meng@databricks.com> Closes #79 from manishamde/tree and squashes the following commits: 1e8c704 [Manish Amde] remove numBins field in the Strategy class 7d54b4f [manishamde] Merge pull request #4 from mengxr/dtree f536ae9 [Xiangrui Meng] another pass on code style e1dd86f [Manish Amde] implementing code style suggestions 62dc723 [Manish Amde] updating javadoc and converting helper methods to package private to allow unit testing 201702f [Manish Amde] making some more methods private f963ef5 [Manish Amde] making methods private c487e6a [manishamde] Merge pull request #1 from mengxr/dtree 24500c5 [Xiangrui Meng] minor style updates 4576b64 [Manish Amde] documentation and for to while loop conversion ff363a7 [Manish Amde] binary search for bins and while loop for categorical feature bins 632818f [Manish Amde] removing threshold for classification predict method 2116360 [Manish Amde] removing dummy bin calculation for categorical variables 6068356 [Manish Amde] ensuring num bins is always greater than max number of categories 62c2562 [Manish Amde] fixing comment indentation ad1fc21 [Manish Amde] incorporated mengxr's code style suggestions d1ef4f6 [Manish Amde] more documentation 794ff4d [Manish Amde] minor improvements to docs and style eb8fcbe [Manish Amde] minor code style updates cd2c2b4 [Manish Amde] fixing code style based on feedback 63e786b [Manish Amde] added multiple train methods for java compatability d3023b3 [Manish Amde] adding more docs for nested methods 84f85d6 [Manish Amde] code documentation 9372779 [Manish Amde] code style: max line lenght <= 100 dd0c0d7 [Manish Amde] minor: some docs 0dd7659 [manishamde] basic doc 5841c28 [Manish Amde] unit tests for categorical features f067d68 [Manish Amde] minor cleanup c0e522b [Manish Amde] updated predict and split threshold logic b09dc98 [Manish Amde] minor refactoring 6b7de78 [Manish Amde] minor refactoring and tests d504eb1 [Manish Amde] more tests for categorical features dbb7ac1 [Manish Amde] categorical feature support 6df35b9 [Manish Amde] regression predict logic 53108ed [Manish Amde] fixing index for highest bin e23c2e5 [Manish Amde] added regression support c8f6d60 [Manish Amde] adding enum for feature type b0e3e76 [Manish Amde] adding enum for feature type 154aa77 [Manish Amde] enums for configurations 733d6dd [Manish Amde] fixed tests 02c595c [Manish Amde] added command line parsing 98ec8d5 [Manish Amde] tree building and prediction logic b0eb866 [Manish Amde] added logic to handle leaf nodes 80e8c66 [Manish Amde] working version of multi-level split calculation 4798aae [Manish Amde] added gain stats class dad0afc [Manish Amde] decison stump functionality working 03f534c [Manish Amde] some more tests 0012a77 [Manish Amde] basic stump working 8bca1e2 [Manish Amde] additional code for creating intermediate RDD 92cedce [Manish Amde] basic building blocks for intermediate RDD calculation. untested. cd53eae [Manish Amde] skeletal framework
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala425
1 files changed, 425 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
new file mode 100644
index 0000000000..4349c7000a
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -0,0 +1,425 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
+import org.apache.spark.mllib.tree.model.Filter
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.FeatureType._
+
+class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
+
+ @transient private var sc: SparkContext = _
+
+ override def beforeAll() {
+ sc = new SparkContext("local", "test")
+ }
+
+ override def afterAll() {
+ sc.stop()
+ System.clearProperty("spark.driver.port")
+ }
+
+ test("split and bin calculation") {
+ val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(Classification, Gini, 3, 100)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ assert(splits.length === 2)
+ assert(bins.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins(0).length === 100)
+ }
+
+ test("split and bin calculation for categorical variables") {
+ val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(
+ Classification,
+ Gini,
+ maxDepth = 3,
+ maxBins = 100,
+ categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ assert(splits.length === 2)
+ assert(bins.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins(0).length === 100)
+
+ // Check splits.
+
+ assert(splits(0)(0).feature === 0)
+ assert(splits(0)(0).threshold === Double.MinValue)
+ assert(splits(0)(0).featureType === Categorical)
+ assert(splits(0)(0).categories.length === 1)
+ assert(splits(0)(0).categories.contains(1.0))
+
+ assert(splits(0)(1).feature === 0)
+ assert(splits(0)(1).threshold === Double.MinValue)
+ assert(splits(0)(1).featureType === Categorical)
+ assert(splits(0)(1).categories.length === 2)
+ assert(splits(0)(1).categories.contains(1.0))
+ assert(splits(0)(1).categories.contains(0.0))
+
+ assert(splits(0)(2) === null)
+
+ assert(splits(1)(0).feature === 1)
+ assert(splits(1)(0).threshold === Double.MinValue)
+ assert(splits(1)(0).featureType === Categorical)
+ assert(splits(1)(0).categories.length === 1)
+ assert(splits(1)(0).categories.contains(0.0))
+
+ assert(splits(1)(1).feature === 1)
+ assert(splits(1)(1).threshold === Double.MinValue)
+ assert(splits(1)(1).featureType === Categorical)
+ assert(splits(1)(1).categories.length === 2)
+ assert(splits(1)(1).categories.contains(1.0))
+ assert(splits(1)(1).categories.contains(0.0))
+
+ assert(splits(1)(2) === null)
+
+ // Check bins.
+
+ assert(bins(0)(0).category === 1.0)
+ assert(bins(0)(0).lowSplit.categories.length === 0)
+ assert(bins(0)(0).highSplit.categories.length === 1)
+ assert(bins(0)(0).highSplit.categories.contains(1.0))
+
+ assert(bins(0)(1).category === 0.0)
+ assert(bins(0)(1).lowSplit.categories.length === 1)
+ assert(bins(0)(1).lowSplit.categories.contains(1.0))
+ assert(bins(0)(1).highSplit.categories.length === 2)
+ assert(bins(0)(1).highSplit.categories.contains(1.0))
+ assert(bins(0)(1).highSplit.categories.contains(0.0))
+
+ assert(bins(0)(2) === null)
+
+ assert(bins(1)(0).category === 0.0)
+ assert(bins(1)(0).lowSplit.categories.length === 0)
+ assert(bins(1)(0).highSplit.categories.length === 1)
+ assert(bins(1)(0).highSplit.categories.contains(0.0))
+
+ assert(bins(1)(1).category === 1.0)
+ assert(bins(1)(1).lowSplit.categories.length === 1)
+ assert(bins(1)(1).lowSplit.categories.contains(0.0))
+ assert(bins(1)(1).highSplit.categories.length === 2)
+ assert(bins(1)(1).highSplit.categories.contains(0.0))
+ assert(bins(1)(1).highSplit.categories.contains(1.0))
+
+ assert(bins(1)(2) === null)
+ }
+
+ test("split and bin calculations for categorical variables with no sample for one category") {
+ val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(
+ Classification,
+ Gini,
+ maxDepth = 3,
+ maxBins = 100,
+ categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+
+ // Check splits.
+
+ assert(splits(0)(0).feature === 0)
+ assert(splits(0)(0).threshold === Double.MinValue)
+ assert(splits(0)(0).featureType === Categorical)
+ assert(splits(0)(0).categories.length === 1)
+ assert(splits(0)(0).categories.contains(1.0))
+
+ assert(splits(0)(1).feature === 0)
+ assert(splits(0)(1).threshold === Double.MinValue)
+ assert(splits(0)(1).featureType === Categorical)
+ assert(splits(0)(1).categories.length === 2)
+ assert(splits(0)(1).categories.contains(1.0))
+ assert(splits(0)(1).categories.contains(0.0))
+
+ assert(splits(0)(2).feature === 0)
+ assert(splits(0)(2).threshold === Double.MinValue)
+ assert(splits(0)(2).featureType === Categorical)
+ assert(splits(0)(2).categories.length === 3)
+ assert(splits(0)(2).categories.contains(1.0))
+ assert(splits(0)(2).categories.contains(0.0))
+ assert(splits(0)(2).categories.contains(2.0))
+
+ assert(splits(0)(3) === null)
+
+ assert(splits(1)(0).feature === 1)
+ assert(splits(1)(0).threshold === Double.MinValue)
+ assert(splits(1)(0).featureType === Categorical)
+ assert(splits(1)(0).categories.length === 1)
+ assert(splits(1)(0).categories.contains(0.0))
+
+ assert(splits(1)(1).feature === 1)
+ assert(splits(1)(1).threshold === Double.MinValue)
+ assert(splits(1)(1).featureType === Categorical)
+ assert(splits(1)(1).categories.length === 2)
+ assert(splits(1)(1).categories.contains(1.0))
+ assert(splits(1)(1).categories.contains(0.0))
+
+ assert(splits(1)(2).feature === 1)
+ assert(splits(1)(2).threshold === Double.MinValue)
+ assert(splits(1)(2).featureType === Categorical)
+ assert(splits(1)(2).categories.length === 3)
+ assert(splits(1)(2).categories.contains(1.0))
+ assert(splits(1)(2).categories.contains(0.0))
+ assert(splits(1)(2).categories.contains(2.0))
+
+ assert(splits(1)(3) === null)
+
+ // Check bins.
+
+ assert(bins(0)(0).category === 1.0)
+ assert(bins(0)(0).lowSplit.categories.length === 0)
+ assert(bins(0)(0).highSplit.categories.length === 1)
+ assert(bins(0)(0).highSplit.categories.contains(1.0))
+
+ assert(bins(0)(1).category === 0.0)
+ assert(bins(0)(1).lowSplit.categories.length === 1)
+ assert(bins(0)(1).lowSplit.categories.contains(1.0))
+ assert(bins(0)(1).highSplit.categories.length === 2)
+ assert(bins(0)(1).highSplit.categories.contains(1.0))
+ assert(bins(0)(1).highSplit.categories.contains(0.0))
+
+ assert(bins(0)(2).category === 2.0)
+ assert(bins(0)(2).lowSplit.categories.length === 2)
+ assert(bins(0)(2).lowSplit.categories.contains(1.0))
+ assert(bins(0)(2).lowSplit.categories.contains(0.0))
+ assert(bins(0)(2).highSplit.categories.length === 3)
+ assert(bins(0)(2).highSplit.categories.contains(1.0))
+ assert(bins(0)(2).highSplit.categories.contains(0.0))
+ assert(bins(0)(2).highSplit.categories.contains(2.0))
+
+ assert(bins(0)(3) === null)
+
+ assert(bins(1)(0).category === 0.0)
+ assert(bins(1)(0).lowSplit.categories.length === 0)
+ assert(bins(1)(0).highSplit.categories.length === 1)
+ assert(bins(1)(0).highSplit.categories.contains(0.0))
+
+ assert(bins(1)(1).category === 1.0)
+ assert(bins(1)(1).lowSplit.categories.length === 1)
+ assert(bins(1)(1).lowSplit.categories.contains(0.0))
+ assert(bins(1)(1).highSplit.categories.length === 2)
+ assert(bins(1)(1).highSplit.categories.contains(0.0))
+ assert(bins(1)(1).highSplit.categories.contains(1.0))
+
+ assert(bins(1)(2).category === 2.0)
+ assert(bins(1)(2).lowSplit.categories.length === 2)
+ assert(bins(1)(2).lowSplit.categories.contains(0.0))
+ assert(bins(1)(2).lowSplit.categories.contains(1.0))
+ assert(bins(1)(2).highSplit.categories.length === 3)
+ assert(bins(1)(2).highSplit.categories.contains(0.0))
+ assert(bins(1)(2).highSplit.categories.contains(1.0))
+ assert(bins(1)(2).highSplit.categories.contains(2.0))
+
+ assert(bins(1)(3) === null)
+ }
+
+ test("classification stump with all categorical variables") {
+ val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(
+ Classification,
+ Gini,
+ maxDepth = 3,
+ maxBins = 100,
+ categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
+ Array[List[Filter]](), splits, bins)
+
+ val split = bestSplits(0)._1
+ assert(split.categories.length === 1)
+ assert(split.categories.contains(1.0))
+ assert(split.featureType === Categorical)
+ assert(split.threshold === Double.MinValue)
+
+ val stats = bestSplits(0)._2
+ assert(stats.gain > 0)
+ assert(stats.predict > 0.4)
+ assert(stats.predict < 0.5)
+ assert(stats.impurity > 0.2)
+ }
+
+ test("regression stump with all categorical variables") {
+ val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(
+ Regression,
+ Variance,
+ maxDepth = 3,
+ maxBins = 100,
+ categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
+ val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
+ Array[List[Filter]](), splits, bins)
+
+ val split = bestSplits(0)._1
+ assert(split.categories.length === 1)
+ assert(split.categories.contains(1.0))
+ assert(split.featureType === Categorical)
+ assert(split.threshold === Double.MinValue)
+
+ val stats = bestSplits(0)._2
+ assert(stats.gain > 0)
+ assert(stats.predict > 0.4)
+ assert(stats.predict < 0.5)
+ assert(stats.impurity > 0.2)
+ }
+
+ test("stump with fixed label 0 for Gini") {
+ val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(Classification, Gini, 3, 100)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ assert(splits.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins.length === 2)
+ assert(bins(0).length === 100)
+ assert(splits(0).length === 99)
+ assert(bins(0).length === 100)
+
+ val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
+ Array[List[Filter]](), splits, bins)
+ assert(bestSplits.length === 1)
+ assert(bestSplits(0)._1.feature === 0)
+ assert(bestSplits(0)._1.threshold === 10)
+ assert(bestSplits(0)._2.gain === 0)
+ assert(bestSplits(0)._2.leftImpurity === 0)
+ assert(bestSplits(0)._2.rightImpurity === 0)
+ }
+
+ test("stump with fixed label 1 for Gini") {
+ val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(Classification, Gini, 3, 100)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ assert(splits.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins.length === 2)
+ assert(bins(0).length === 100)
+ assert(splits(0).length === 99)
+ assert(bins(0).length === 100)
+
+ val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
+ Array[List[Filter]](), splits, bins)
+ assert(bestSplits.length === 1)
+ assert(bestSplits(0)._1.feature === 0)
+ assert(bestSplits(0)._1.threshold === 10)
+ assert(bestSplits(0)._2.gain === 0)
+ assert(bestSplits(0)._2.leftImpurity === 0)
+ assert(bestSplits(0)._2.rightImpurity === 0)
+ assert(bestSplits(0)._2.predict === 1)
+ }
+
+ test("stump with fixed label 0 for Entropy") {
+ val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(Classification, Entropy, 3, 100)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ assert(splits.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins.length === 2)
+ assert(bins(0).length === 100)
+ assert(splits(0).length === 99)
+ assert(bins(0).length === 100)
+
+ val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
+ Array[List[Filter]](), splits, bins)
+ assert(bestSplits.length === 1)
+ assert(bestSplits(0)._1.feature === 0)
+ assert(bestSplits(0)._1.threshold === 10)
+ assert(bestSplits(0)._2.gain === 0)
+ assert(bestSplits(0)._2.leftImpurity === 0)
+ assert(bestSplits(0)._2.rightImpurity === 0)
+ assert(bestSplits(0)._2.predict === 0)
+ }
+
+ test("stump with fixed label 1 for Entropy") {
+ val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(Classification, Entropy, 3, 100)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ assert(splits.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins.length === 2)
+ assert(bins(0).length === 100)
+ assert(splits(0).length === 99)
+ assert(bins(0).length === 100)
+
+ val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
+ Array[List[Filter]](), splits, bins)
+ assert(bestSplits.length === 1)
+ assert(bestSplits(0)._1.feature === 0)
+ assert(bestSplits(0)._1.threshold === 10)
+ assert(bestSplits(0)._2.gain === 0)
+ assert(bestSplits(0)._2.leftImpurity === 0)
+ assert(bestSplits(0)._2.rightImpurity === 0)
+ assert(bestSplits(0)._2.predict === 1)
+ }
+}
+
+object DecisionTreeSuite {
+
+ def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = {
+ val arr = new Array[LabeledPoint](1000)
+ for (i <- 0 until 1000){
+ val lp = new LabeledPoint(0.0,Array(i.toDouble,1000.0-i))
+ arr(i) = lp
+ }
+ arr
+ }
+
+ def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = {
+ val arr = new Array[LabeledPoint](1000)
+ for (i <- 0 until 1000){
+ val lp = new LabeledPoint(1.0,Array(i.toDouble,999.0-i))
+ arr(i) = lp
+ }
+ arr
+ }
+
+ def generateCategoricalDataPoints(): Array[LabeledPoint] = {
+ val arr = new Array[LabeledPoint](1000)
+ for (i <- 0 until 1000){
+ if (i < 600){
+ arr(i) = new LabeledPoint(1.0,Array(0.0,1.0))
+ } else {
+ arr(i) = new LabeledPoint(0.0,Array(1.0,0.0))
+ }
+ }
+ arr
+ }
+}