author    Sean Owen <sowen@cloudera.com>  2015-03-26 15:00:23 +0000
committer Sean Owen <sowen@cloudera.com>  2015-03-26 15:00:42 +0000
commit    758ebf77d7daded7c5f6f41ee269205bc246d487 (patch)
tree      56e889c9691fb0795bcaf1890853ca208ae95e23
parent    61c059a4ace4007cccbb3ffcc2a382acdaf7196a (diff)
SPARK-6480 [CORE] histogram() bucket function is wrong in some simple edge cases
Fix fastBucketFunction for histogram() to handle edge conditions more
correctly. Add a test, and fix the existing one accordingly.

Author: Sean Owen <sowen@cloudera.com>

Closes #5148 from srowen/SPARK-6480 and squashes the following commits:

974a0a0 [Sean Owen] Additional test of huge ranges, and a few more comments (and comment fixes)
23ec01e [Sean Owen] Fix fastBucketFunction for histogram() to handle edge conditions more correctly. Add a test, and fix existing one accordingly

(cherry picked from commit fe15ea976073edd738c006af1eb8d31617a039fc)
Signed-off-by: Sean Owen <sowen@cloudera.com>
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala | 20
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala     | 24
2 files changed, 29 insertions(+), 15 deletions(-)
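The heart of the bug: the old fastBucketFunction divided by a single bucket width, buckets(1) - buckets(0), so rounding error could put a value that sits exactly on a distant bucket boundary one bucket too low. The following is a minimal standalone sketch, plain Scala with no Spark required; the object name and the oldIndex/newIndex helpers are illustrative additions, while the two formulas are taken from the diff below. Boundaries are built as min + (s * span) / steps, which reproduces the bucket values the updated test comment lists (6.0, 16.333333333333336, ..., 37.0, ...).

object BucketEdgeDemo extends App {
  // Evenly spaced boundaries for [6.0, 99.0] with 9 buckets.
  val (min, max, count) = (6.0, 99.0, 9)
  val buckets = (0 to count).map(s => min + (s * (max - min)) / count).toArray

  // Old formula: divide e's offset by the first bucket's width.
  def oldIndex(e: Double): Int = ((e - min) / (buckets(1) - buckets(0))).toInt
  // New formula: compute e's relative position along the whole range first.
  def newIndex(e: Double): Int =
    math.min((((e - min) / (max - min)) * count).toInt, count - 1)

  // buckets(3) is exactly 37.0, so 37.0 should open bucket 3. Accumulated
  // rounding error in the old formula drops it into bucket 2 instead.
  println(buckets(3))      // 37.0
  println(oldIndex(37.0))  // 2 -- off by one
  println(newIndex(37.0))  // 3 -- correct
}

This is exactly why the expected counts in the renamed WorksWithoutBucketsWithNonIntegralBucketEdges test change from Array(11, 10, 11, ...) to Array(11, 10, 10, 11, ...) below: the integer 37 moves from bucket 2 to bucket 3.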
diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
index e0494ee396..e66c06e5be 100644
--- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
@@ -192,25 +192,23 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
       }
     }
     // Determine the bucket function in constant time. Requires that buckets are evenly spaced
-    def fastBucketFunction(min: Double, increment: Double, count: Int)(e: Double): Option[Int] = {
+    def fastBucketFunction(min: Double, max: Double, count: Int)(e: Double): Option[Int] = {
       // If our input is not a number unless the increment is also NaN then we fail fast
-      if (e.isNaN()) {
-        return None
-      }
-      val bucketNumber = (e - min)/(increment)
-      // We do this rather than buckets.lengthCompare(bucketNumber)
-      // because Array[Double] fails to override it (for now).
-      if (bucketNumber > count || bucketNumber < 0) {
+      if (e.isNaN || e < min || e > max) {
         None
       } else {
-        Some(bucketNumber.toInt.min(count - 1))
+        // Compute ratio of e's distance along range to total range first, for better precision
+        val bucketNumber = (((e - min) / (max - min)) * count).toInt
+        // should be less than count, but will equal count if e == max, in which case
+        // it's part of the last end-range-inclusive bucket, so return count-1
+        Some(math.min(bucketNumber, count - 1))
       }
     }
     // Decide which bucket function to pass to histogramPartition. We decide here
-    // rather than having a general function so that the decission need only be made
+    // rather than having a general function so that the decision need only be made
     // once rather than once per shard
     val bucketFunction = if (evenBuckets) {
-      fastBucketFunction(buckets(0), buckets(1)-buckets(0), buckets.length-1) _
+      fastBucketFunction(buckets.head, buckets.last, buckets.length - 1) _
     } else {
       basicBucketFunction _
     }
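The patched decision logic is self-contained enough to run outside Spark. A sketch follows; the surrounding object, the partially applied bucket value, and the asserts are illustrative additions, while the function body is as in the patch. It exercises the three edge conditions the fix covers: NaN input, out-of-range input, and a value equal to max.

object BucketFunctionSketch extends App {
  // The fixed logic: range check up front, ratio-first index computation,
  // and e == max folded into the last (end-inclusive) bucket.
  def fastBucketFunction(min: Double, max: Double, count: Int)(e: Double): Option[Int] = {
    if (e.isNaN || e < min || e > max) {
      None
    } else {
      val bucketNumber = (((e - min) / (max - min)) * count).toInt
      Some(math.min(bucketNumber, count - 1))
    }
  }

  val bucket = fastBucketFunction(1.0, 3.0, 4) _  // 4 buckets over [1.0, 3.0]
  assert(bucket(Double.NaN).isEmpty)  // NaN is never bucketed
  assert(bucket(5.0).isEmpty)         // above max: no bucket, no exception
  assert(bucket(3.0).contains(3))     // e == max joins the last bucket, not a phantom bucket 4
}

Computing the ratio before scaling keeps the intermediate value in [0, 1] for any in-range input, which is what lets the new WorksWithHugeRange test below span 30 orders of magnitude without losing either endpoint.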
diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala
index f89bdb6e07..e29ac0c4fc 100644
--- a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala
@@ -233,6 +233,12 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext {
     assert(histogramBuckets === expectedHistogramBuckets)
   }
 
+  test("WorksWithDoubleValuesAtMinMax") {
+    val rdd = sc.parallelize(Seq(1, 1, 1, 2, 3, 3))
+    assert(Array(3, 0, 1, 2) === rdd.map(_.toDouble).histogram(4)._2)
+    assert(Array(3, 1, 2) === rdd.map(_.toDouble).histogram(3)._2)
+  }
+
   test("WorksWithoutBucketsWithMoreRequestedThanElements") {
     // Verify the basic case of one bucket and all elements in that bucket works
     val rdd = sc.parallelize(Seq(1, 2))
@@ -246,7 +252,7 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext {
   }
 
   test("WorksWithoutBucketsForLargerDatasets") {
-    // Verify the case of slighly larger datasets
+    // Verify the case of slightly larger datasets
     val rdd = sc.parallelize(6 to 99)
     val (histogramBuckets, histogramResults) = rdd.histogram(8)
     val expectedHistogramResults =
@@ -257,17 +263,27 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext {
     assert(histogramBuckets === expectedHistogramBuckets)
   }
 
-  test("WorksWithoutBucketsWithIrrationalBucketEdges") {
-    // Verify the case of buckets with irrational edges. See #SPARK-2862.
+  test("WorksWithoutBucketsWithNonIntegralBucketEdges") {
+    // Verify the case of buckets with nonintegral edges. See #SPARK-2862.
     val rdd = sc.parallelize(6 to 99)
     val (histogramBuckets, histogramResults) = rdd.histogram(9)
+    // Buckets are 6.0, 16.333333333333336, 26.666666666666668, 37.0, 47.333333333333336 ...
     val expectedHistogramResults =
-      Array(11, 10, 11, 10, 10, 11, 10, 10, 11)
+      Array(11, 10, 10, 11, 10, 10, 11, 10, 11)
     assert(histogramResults === expectedHistogramResults)
     assert(histogramBuckets(0) === 6.0)
     assert(histogramBuckets(9) === 99.0)
   }
 
+  test("WorksWithHugeRange") {
+    val rdd = sc.parallelize(Array(0, 1.0e24, 1.0e30))
+    val histogramResults = rdd.histogram(1000000)._2
+    assert(histogramResults(0) === 1)
+    assert(histogramResults(1) === 1)
+    assert(histogramResults.last === 1)
+    assert((2 to histogramResults.length - 2).forall(i => histogramResults(i) == 0))
+  }
+
   // Test the failure mode with an invalid RDD
   test("ThrowsExceptionOnInvalidRDDs") {
     // infinity
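A closing note on the "compute ratio first, for better precision" comment in the patch: the order of operations also matters in the extreme. The sketch below is an illustration of my own, not taken from the patch or its tests; it shows that scaling before dividing can overflow to Infinity for in-range values near Double.MaxValue, which, even after the min(count - 1) clamp, would dump such values into the last bucket. The ratio-first form never exceeds count for in-range input, so no overflow is possible.

object RatioFirstDemo extends App {
  val (min, max, count) = (0.0, Double.MaxValue, 4)
  val e = 0.6 * Double.MaxValue  // in range; belongs in bucket 2 of 4

  // Patch's order: ratio (at most 1.0) first, then scale by count.
  val ratioFirst = (((e - min) / (max - min)) * count).toInt
  // Reversed order: (e - min) * count overflows to Infinity here,
  // and Infinity.toInt is Int.MaxValue.
  val scaleFirst = (((e - min) * count) / (max - min)).toInt

  println(ratioFirst)  // 2
  println(scaleFirst)  // 2147483647 -- clamped later, it would land in bucket 3
}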