aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorsethah <seth.hendrickson16@gmail.com>2016-03-17 16:44:41 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-03-17 16:44:41 -0700
commit1614485fd92fc94bc3989da49be612e542b93fb8 (patch)
treeca71f1fee26fcb993316bd0a6e545ae2003ab107 /mllib/src/test
parentb39e80d39dae8e6779f9d78c1631a27585239032 (diff)
downloadspark-1614485fd92fc94bc3989da49be612e542b93fb8.tar.gz
spark-1614485fd92fc94bc3989da49be612e542b93fb8.tar.bz2
spark-1614485fd92fc94bc3989da49be612e542b93fb8.zip
[SPARK-10788][MLLIB][ML] Remove duplicate bins for decision trees
Decision trees in spark.ml (RandomForest.scala) communicate twice as much data as needed for unordered categorical features. Here's an example. Say there are 3 categories A, B, C. We consider 3 splits: * A vs. B, C * A, B vs. C * A, C vs. B Currently, we collect statistics for each of the 6 subsets of categories (3 * 2 = 6). However, we could instead collect statistics for the 3 subsets on the left-hand side of the 3 possible splits: A and A,B and A,C. If we also have stats for the entire node, then we can compute the stats for the 3 subsets on the right-hand side of the splits. In pseudomath: stats(B,C) = stats(A,B,C) - stats(A). This patch adds a parent stats array to the `DTStatsAggregator` so that the right child stats do not need to be stored. The right child stats are computed by subtracting left child stats from the parent stats for unordered categorical features. Author: sethah <seth.hendrickson16@gmail.com> Closes #9474 from sethah/SPARK-10788.
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala4
1 files changed, 4 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 5518bdf527..89b64fce96 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -189,6 +189,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(bins.length === 2)
assert(splits(0).length === 3)
assert(bins(0).length === 0)
+ assert(metadata.numSplits(0) === 3)
+ assert(metadata.numBins(0) === 3)
+ assert(metadata.numSplits(1) === 3)
+ assert(metadata.numBins(1) === 3)
// Expecting 2^2 - 1 = 3 bins/splits
assert(splits(0)(0).feature === 0)