From 1614485fd92fc94bc3989da49be612e542b93fb8 Mon Sep 17 00:00:00 2001
From: sethah
Date: Thu, 17 Mar 2016 16:44:41 -0700
Subject: [SPARK-10788][MLLIB][ML] Remove duplicate bins for decision trees

Decision trees in spark.ml (RandomForest.scala) communicate twice as much data as needed for unordered categorical features. Here's an example.

Say there are 3 categories A, B, C. We consider 3 splits:

* A vs. B, C
* A, B vs. C
* A, C vs. B

Currently, we collect statistics for each of the 6 subsets of categories (3 * 2 = 6). However, we could instead collect statistics for the 3 subsets on the left-hand side of the 3 possible splits: A and A,B and A,C. If we also have stats for the entire node, then we can compute the stats for the 3 subsets on the right-hand side of the splits. In pseudomath: stats(B,C) = stats(A,B,C) - stats(A).

This patch adds a parent stats array to the `DTStatsAggregator` so that the right child stats do not need to be stored. The right child stats are computed by subtracting left child stats from the parent stats for unordered categorical features.

Author: sethah

Closes #9474 from sethah/SPARK-10788.
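The subtraction described above can be sketched in a few lines of Scala. This is only an illustration of the idea, not the code in this patch: the object `UnorderedSplitStats`, its helper names, and the use of per-class label counts as the "stats" are all assumptions made for the example, and none of them are Spark APIs.

```scala
// Illustrative sketch only: every name here is hypothetical, not a Spark API.
object UnorderedSplitStats {
  // Assumed representation of "stats": per-class label counts.
  type Stats = Vector[Double]

  def add(a: Stats, b: Stats): Stats = a.zip(b).map { case (x, y) => x + y }
  def subtract(a: Stats, b: Stats): Stats = a.zip(b).map { case (x, y) => x - y }

  // Left-hand subsets for the 2^(M-1) - 1 splits of M categories.
  // Category 0 always stays on the left so that each split is listed once.
  def leftSubsets(numCategories: Int): Vector[Set[Int]] = {
    val rest = (1 until numCategories).toSet
    rest.subsets().filter(_.size < rest.size).map(_ + 0).toVector
  }

  // Stats for a left-hand subset: the sum over its categories.
  // toSeq avoids collapsing categories that happen to have equal stats.
  def leftStats(perCategory: Map[Int, Stats], left: Set[Int]): Stats =
    left.toSeq.map(perCategory).reduce(add)

  // Right-hand stats are never stored; they are derived from the parent.
  def rightStats(parent: Stats, left: Stats): Stats = subtract(parent, left)
}
```

With categories A, B, C numbered 0, 1, 2, `leftSubsets(3)` yields the three left-hand sides A and A,B and A,C, and `rightStats(parent, leftStats(perCategory, Set(0)))` corresponds to stats(B,C) = stats(A,B,C) - stats(A).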
---
 .../test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala | 4 ++++
 1 file changed, 4 insertions(+) (limited to 'mllib/src/test')

diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 5518bdf527..89b64fce96 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -189,6 +189,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(bins.length === 2)
     assert(splits(0).length === 3)
     assert(bins(0).length === 0)
+    assert(metadata.numSplits(0) === 3)
+    assert(metadata.numBins(0) === 3)
+    assert(metadata.numSplits(1) === 3)
+    assert(metadata.numBins(1) === 3)
 
     // Expecting 2^2 - 1 = 3 bins/splits
     assert(splits(0)(0).feature === 0)
--
cgit v1.2.3