diff options
author | Nathan Howell <nhowell@godaddy.com> | 2015-10-07 17:46:16 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-10-07 17:46:16 -0700 |
commit | 1bc435ae3afb7a007b8a8ff00dcad4738a9ff055 (patch) | |
tree | 9e8ff1046739f58d90941d027faf4bf91f76eac6 /mllib/src/test/scala/org/apache | |
parent | 075a0b658289608c8732e07e26e14d736e673ce9 (diff) | |
download | spark-1bc435ae3afb7a007b8a8ff00dcad4738a9ff055.tar.gz spark-1bc435ae3afb7a007b8a8ff00dcad4738a9ff055.tar.bz2 spark-1bc435ae3afb7a007b8a8ff00dcad4738a9ff055.zip |
[SPARK-10064] [ML] Parallelize decision tree bin split calculations
Reimplement `DecisionTree.findSplitsBins` via `RDD` to parallelize bin calculation.
With large feature spaces, the current implementation is very slow. This change limits the features that are distributed (or collected) to just the continuous features, and performs the split calculations in parallel. The new implementation completes on a real multi-terabyte dataset in less than a minute instead of multiple hours.
Author: Nathan Howell <nhowell@godaddy.com>
Closes #8246 from NathanHowell/SPARK-10064.
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala | 6 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala | 4 |
2 files changed, 2 insertions, 8 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 356d957f15..1a4299db4e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -135,8 +135,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble) val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) assert(splits.length === 3) - assert(fakeMetadata.numSplits(0) === 3) - assert(fakeMetadata.numBins(0) === 4) // check returned splits are distinct assert(splits.distinct.length === splits.length) } @@ -151,8 +149,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble) val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) assert(splits.length === 2) - assert(fakeMetadata.numSplits(0) === 2) - assert(fakeMetadata.numBins(0) === 3) assert(splits(0) === 2.0) assert(splits(1) === 3.0) } @@ -167,8 +163,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) assert(splits.length === 1) - assert(fakeMetadata.numSplits(0) === 1) - assert(fakeMetadata.numBins(0) === 2) assert(splits(0) === 1.0) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala index 334bf3790f..3d3f80063f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala +++ 
b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala @@ -69,8 +69,8 @@ object EnsembleTestHelper { required: Double, metricName: String = "mse") { val predictions = input.map(x => model.predict(x.features)) - val errors = predictions.zip(input.map(_.label)).map { case (prediction, label) => - label - prediction + val errors = predictions.zip(input).map { case (prediction, point) => + point.label - prediction } val metric = metricName match { case "mse" => |