path: root/mllib/src/test/scala
author     Nathan Howell <nhowell@godaddy.com>        2015-10-07 17:46:16 -0700
committer  Joseph K. Bradley <joseph@databricks.com>  2015-10-07 17:46:16 -0700
commit     1bc435ae3afb7a007b8a8ff00dcad4738a9ff055 (patch)
tree       9e8ff1046739f58d90941d027faf4bf91f76eac6 /mllib/src/test/scala
parent     075a0b658289608c8732e07e26e14d736e673ce9 (diff)
download   spark-1bc435ae3afb7a007b8a8ff00dcad4738a9ff055.tar.gz
           spark-1bc435ae3afb7a007b8a8ff00dcad4738a9ff055.tar.bz2
           spark-1bc435ae3afb7a007b8a8ff00dcad4738a9ff055.zip
[SPARK-10064] [ML] Parallelize decision tree bin split calculations
Reimplement `DecisionTree.findSplitsBins` via `RDD` to parallelize bin calculation.

With large feature spaces the current implementation is very slow. This change limits the features that are distributed (or collected) to just the continuous features, and performs the split calculations in parallel. It completes on a real multi-terabyte dataset in less than a minute instead of multiple hours.

Author: Nathan Howell <nhowell@godaddy.com>

Closes #8246 from NathanHowell/SPARK-10064.
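As a rough illustration of the approach the message describes, a minimal, self-contained sketch is shown below (this is not the Spark implementation: `ParallelSplitSketch`, `candidateSplits`, the row layout, and the equal-width stride over distinct values are all assumptions made for the example). It distributes only the continuous features and computes each feature's candidate thresholds in its own task:

import org.apache.spark.rdd.RDD

object ParallelSplitSketch {
  // Hypothetical helper: pick up to numSplits thresholds from one feature's sampled values.
  def candidateSplits(values: Iterable[Double], numSplits: Int): Array[Double] = {
    val sorted = values.toArray.sorted.distinct
    if (sorted.length <= numSplits + 1) {
      sorted.dropRight(1)                                   // every distinct value except the maximum
    } else {
      val stride = sorted.length.toDouble / (numSplits + 1)
      (1 to numSplits).map(i => sorted((i * stride).toInt)).distinct.toArray
    }
  }

  // Ship only the continuous features to the cluster; each feature's splits are computed in parallel.
  def findContinuousSplits(
      samples: RDD[Array[Double]],                          // sampled rows of feature values
      continuousFeatures: Set[Int],
      numSplits: Int): Map[Int, Array[Double]] = {
    samples
      .flatMap(row => continuousFeatures.iterator.map(f => (f, row(f))))
      .groupByKey()                                         // one group of values per continuous feature
      .mapValues(candidateSplits(_, numSplits))
      .collectAsMap()
      .toMap
  }
}

In this sketch only (feature index, value) pairs for continuous features are shuffled, and only the small per-feature split arrays come back to the driver, which mirrors the commit message's point about limiting what is distributed or collected to the continuous features.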
Diffstat (limited to 'mllib/src/test/scala')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala   | 6
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala  | 4
2 files changed, 2 insertions, 8 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 356d957f15..1a4299db4e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -135,8 +135,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits.length === 3)
- assert(fakeMetadata.numSplits(0) === 3)
- assert(fakeMetadata.numBins(0) === 4)
// check returned splits are distinct
assert(splits.distinct.length === splits.length)
}
@@ -151,8 +149,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits.length === 2)
- assert(fakeMetadata.numSplits(0) === 2)
- assert(fakeMetadata.numBins(0) === 3)
assert(splits(0) === 2.0)
assert(splits(1) === 3.0)
}
@@ -167,8 +163,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits.length === 1)
- assert(fakeMetadata.numSplits(0) === 1)
- assert(fakeMetadata.numBins(0) === 2)
assert(splits(0) === 1.0)
}
}
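The removed assertions above had checked values on the fake metadata object; the assertions that remain look only at the thresholds returned for the continuous feature. For intuition about what these tests exercise, here is a simplified, standalone stand-in for equal-frequency threshold selection (this is not the Spark source and is not guaranteed to reproduce the exact counts asserted above; `ContinuousSplitSketch` and its details are assumptions for illustration):

object ContinuousSplitSketch {
  // Pick up to numSplits thresholds so each bin covers roughly the same number of samples.
  def findSplits(featureSamples: Array[Double], numSplits: Int): Array[Double] = {
    // Occurrence count of each distinct value, sorted ascending.
    val valueCounts = featureSamples
      .groupBy(identity)
      .map { case (v, vs) => (v, vs.length) }
      .toArray
      .sortBy(_._1)

    if (valueCounts.length <= numSplits + 1) {
      // Few distinct values: every distinct value except the maximum becomes a threshold.
      valueCounts.map(_._1).dropRight(1)
    } else {
      // Walk cumulative counts, emitting a threshold whenever an equal-frequency target is crossed.
      val stride = featureSamples.length.toDouble / (numSplits + 1)
      val splits = Array.newBuilder[Double]
      var target = stride
      var cumulative = valueCounts(0)._2.toDouble
      var i = 1
      while (i < valueCounts.length) {
        val previous = cumulative
        cumulative += valueCounts(i)._2
        if (math.abs(previous - target) < math.abs(cumulative - target)) {
          splits += valueCounts(i - 1)._1
          target += stride
        }
        i += 1
      }
      splits.result()
    }
  }
}

Heavily repeated values collapse the candidate thresholds onto the few distinct boundaries, which is why these tests feed in skewed samples and then check that the returned splits are distinct.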
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
index 334bf3790f..3d3f80063f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
@@ -69,8 +69,8 @@ object EnsembleTestHelper {
required: Double,
metricName: String = "mse") {
val predictions = input.map(x => model.predict(x.features))
- val errors = predictions.zip(input.map(_.label)).map { case (prediction, label) =>
- label - prediction
+ val errors = predictions.zip(input).map { case (prediction, point) =>
+ point.label - prediction
}
val metric = metricName match {
case "mse" =>