aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authorMatei Zaharia <matei@eecs.berkeley.edu>2013-11-26 00:00:07 -0800
committerMatei Zaharia <matei@eecs.berkeley.edu>2013-11-26 00:00:07 -0800
commit18d6df0e171454ada4d260bfe8b909eedf25304f (patch)
tree8982f01c8096b321c71b141874260c18a04b481e /core
parent0e2109ddb2f27d8a6a9f125206674273b03d1f5e (diff)
parent7222ee29779c3c5146aa5a3d6d060f3b039c1ff7 (diff)
downloadspark-18d6df0e171454ada4d260bfe8b909eedf25304f.tar.gz
spark-18d6df0e171454ada4d260bfe8b909eedf25304f.tar.bz2
spark-18d6df0e171454ada4d260bfe8b909eedf25304f.zip
Merge pull request #86 from holdenk/master
Add histogram functionality to DoubleRDDFunctions This pull request add histogram functionality to the DoubleRDDFunctions.
Diffstat (limited to 'core')
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala40
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala126
-rw-r--r--core/src/test/scala/org/apache/spark/JavaAPISuite.java14
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala271
4 files changed, 451 insertions, 0 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
index 043cb183ba..9f02a9b7d3 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -26,6 +26,8 @@ import org.apache.spark.storage.StorageLevel
import java.lang.Double
import org.apache.spark.Partitioner
+import scala.collection.JavaConverters._
+
class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, JavaDoubleRDD] {
override val classManifest: ClassManifest[Double] = implicitly[ClassManifest[Double]]
@@ -182,6 +184,44 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
/** (Experimental) Approximate operation to return the sum within a timeout. */
def sumApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.sumApprox(timeout)
+
+ /**
+ * Compute a histogram of the data using bucketCount number of buckets evenly
+ * spaced between the minimum and maximum of the RDD. For example if the min
+ * value is 0 and the max is 100 and there are two buckets the resulting
+ * buckets will be [0,50) [50,100]. bucketCount must be at least 1
+ * If the RDD contains infinity, NaN throws an exception
+ * If the elements in RDD do not vary (max == min) always returns a single bucket.
+ */
+ def histogram(bucketCount: Int): Pair[Array[scala.Double], Array[Long]] = {
+ val result = srdd.histogram(bucketCount)
+ (result._1, result._2)
+ }
+
+ /**
+ * Compute a histogram using the provided buckets. The buckets are all open
+ * to the left except for the last which is closed
+ * e.g. for the array
+ * [1,10,20,50] the buckets are [1,10) [10,20) [20,50]
+ * e.g 1<=x<10 , 10<=x<20, 20<=x<50
+ * And on the input of 1 and 50 we would have a histogram of 1,0,0
+ *
+ * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched
+ * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets
+ * to true.
+ * buckets must be sorted and not contain any duplicates.
+ * buckets array must be at least two elements
+ * All NaN entries are treated the same. If you have a NaN bucket it must be
+ * the maximum value of the last position and all NaN entries will be counted
+ * in that bucket.
+ */
+ def histogram(buckets: Array[scala.Double]): Array[Long] = {
+ srdd.histogram(buckets, false)
+ }
+
+ def histogram(buckets: Array[Double], evenBuckets: Boolean): Array[Long] = {
+ srdd.histogram(buckets.map(_.toDouble), evenBuckets)
+ }
}
object JavaDoubleRDD {
diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
index a4bec41752..02d75eccc5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
@@ -24,6 +24,8 @@ import org.apache.spark.partial.SumEvaluator
import org.apache.spark.util.StatCounter
import org.apache.spark.{TaskContext, Logging}
+import scala.collection.immutable.NumericRange
+
/**
* Extra functions available on RDDs of Doubles through an implicit conversion.
* Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
@@ -76,4 +78,128 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
val evaluator = new SumEvaluator(self.partitions.size, confidence)
self.context.runApproximateJob(self, processPartition, evaluator, timeout)
}
+
+ /**
+ * Compute a histogram of the data using bucketCount number of buckets evenly
+ * spaced between the minimum and maximum of the RDD. For example if the min
+ * value is 0 and the max is 100 and there are two buckets the resulting
+ * buckets will be [0, 50) [50, 100]. bucketCount must be at least 1
+ * If the RDD contains infinity, NaN throws an exception
+ * If the elements in RDD do not vary (max == min) always returns a single bucket.
+ */
+ def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = {
+ // Compute the minimum and the maxium
+ val (max: Double, min: Double) = self.mapPartitions { items =>
+ Iterator(items.foldRight(-1/0.0, Double.NaN)((e: Double, x: Pair[Double, Double]) =>
+ (x._1.max(e), x._2.min(e))))
+ }.reduce { (maxmin1, maxmin2) =>
+ (maxmin1._1.max(maxmin2._1), maxmin1._2.min(maxmin2._2))
+ }
+ if (max.isNaN() || max.isInfinity || min.isInfinity ) {
+ throw new UnsupportedOperationException(
+ "Histogram on either an empty RDD or RDD containing +/-infinity or NaN")
+ }
+ val increment = (max-min)/bucketCount.toDouble
+ val range = if (increment != 0) {
+ Range.Double.inclusive(min, max, increment)
+ } else {
+ List(min, min)
+ }
+ val buckets = range.toArray
+ (buckets, histogram(buckets, true))
+ }
+
+ /**
+ * Compute a histogram using the provided buckets. The buckets are all open
+ * to the left except for the last which is closed
+ * e.g. for the array
+ * [1, 10, 20, 50] the buckets are [1, 10) [10, 20) [20, 50]
+ * e.g 1<=x<10 , 10<=x<20, 20<=x<50
+ * And on the input of 1 and 50 we would have a histogram of 1, 0, 0
+ *
+ * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched
+ * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets
+ * to true.
+ * buckets must be sorted and not contain any duplicates.
+ * buckets array must be at least two elements
+ * All NaN entries are treated the same. If you have a NaN bucket it must be
+ * the maximum value of the last position and all NaN entries will be counted
+ * in that bucket.
+ */
+ def histogram(buckets: Array[Double], evenBuckets: Boolean = false): Array[Long] = {
+ if (buckets.length < 2) {
+ throw new IllegalArgumentException("buckets array must have at least two elements")
+ }
+ // The histogramPartition function computes the partail histogram for a given
+ // partition. The provided bucketFunction determines which bucket in the array
+ // to increment or returns None if there is no bucket. This is done so we can
+ // specialize for uniformly distributed buckets and save the O(log n) binary
+ // search cost.
+ def histogramPartition(bucketFunction: (Double) => Option[Int])(iter: Iterator[Double]):
+ Iterator[Array[Long]] = {
+ val counters = new Array[Long](buckets.length - 1)
+ while (iter.hasNext) {
+ bucketFunction(iter.next()) match {
+ case Some(x: Int) => {counters(x) += 1}
+ case _ => {}
+ }
+ }
+ Iterator(counters)
+ }
+ // Merge the counters.
+ def mergeCounters(a1: Array[Long], a2: Array[Long]): Array[Long] = {
+ a1.indices.foreach(i => a1(i) += a2(i))
+ a1
+ }
+ // Basic bucket function. This works using Java's built in Array
+ // binary search. Takes log(size(buckets))
+ def basicBucketFunction(e: Double): Option[Int] = {
+ val location = java.util.Arrays.binarySearch(buckets, e)
+ if (location < 0) {
+ // If the location is less than 0 then the insertion point in the array
+ // to keep it sorted is -location-1
+ val insertionPoint = -location-1
+ // If we have to insert before the first element or after the last one
+ // its out of bounds.
+ // We do this rather than buckets.lengthCompare(insertionPoint)
+ // because Array[Double] fails to override it (for now).
+ if (insertionPoint > 0 && insertionPoint < buckets.length) {
+ Some(insertionPoint-1)
+ } else {
+ None
+ }
+ } else if (location < buckets.length - 1) {
+ // Exact match, just insert here
+ Some(location)
+ } else {
+ // Exact match to the last element
+ Some(location - 1)
+ }
+ }
+ // Determine the bucket function in constant time. Requires that buckets are evenly spaced
+ def fastBucketFunction(min: Double, increment: Double, count: Int)(e: Double): Option[Int] = {
+ // If our input is not a number unless the increment is also NaN then we fail fast
+ if (e.isNaN()) {
+ return None
+ }
+ val bucketNumber = (e - min)/(increment)
+ // We do this rather than buckets.lengthCompare(bucketNumber)
+ // because Array[Double] fails to override it (for now).
+ if (bucketNumber > count || bucketNumber < 0) {
+ None
+ } else {
+ Some(bucketNumber.toInt.min(count - 1))
+ }
+ }
+ // Decide which bucket function to pass to histogramPartition. We decide here
+ // rather than having a general function so that the decission need only be made
+ // once rather than once per shard
+ val bucketFunction = if (evenBuckets) {
+ fastBucketFunction(buckets(0), buckets(1)-buckets(0), buckets.length-1) _
+ } else {
+ basicBucketFunction _
+ }
+ self.mapPartitions(histogramPartition(bucketFunction)).reduce(mergeCounters)
+ }
+
}
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 352036f182..4234f6eac7 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -365,6 +365,20 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void javaDoubleRDDHistoGram() {
+ JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0));
+ // Test using generated buckets
+ Tuple2<double[], long[]> results = rdd.histogram(2);
+ double[] expected_buckets = {1.0, 2.5, 4.0};
+ long[] expected_counts = {2, 2};
+ Assert.assertArrayEquals(expected_buckets, results._1, 0.1);
+ Assert.assertArrayEquals(expected_counts, results._2);
+ // Test with provided buckets
+ long[] histogram = rdd.histogram(expected_buckets);
+ Assert.assertArrayEquals(expected_counts, histogram);
+ }
+
+ @Test
public void map() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
JavaDoubleRDD doubles = rdd.map(new DoubleFunction<Integer>() {
diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala
new file mode 100644
index 0000000000..7f50a5a47c
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala
@@ -0,0 +1,271 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.math.abs
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.rdd._
+import org.apache.spark._
+
+class DoubleRDDSuite extends FunSuite with SharedSparkContext {
+ // Verify tests on the histogram functionality. We test with both evenly
+ // and non-evenly spaced buckets as the bucket lookup function changes.
+ test("WorksOnEmpty") {
+ // Make sure that it works on an empty input
+ val rdd: RDD[Double] = sc.parallelize(Seq())
+ val buckets = Array(0.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksWithOutOfRangeWithOneBucket") {
+ // Verify that if all of the elements are out of range the counts are zero
+ val rdd = sc.parallelize(Seq(10.01, -0.01))
+ val buckets = Array(0.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithOneBucket") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2, 3, 4))
+ val buckets = Array(0.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(4)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithOneBucketExactMatch") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2, 3, 4))
+ val buckets = Array(1.0, 4.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(4)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksWithOutOfRangeWithTwoBuckets") {
+ // Verify that out of range works with two buckets
+ val rdd = sc.parallelize(Seq(10.01, -0.01))
+ val buckets = Array(0.0, 5.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(0, 0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksWithOutOfRangeWithTwoUnEvenBuckets") {
+ // Verify that out of range works with two un even buckets
+ val rdd = sc.parallelize(Seq(10.01, -0.01))
+ val buckets = Array(0.0, 4.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(0, 0)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithTwoBuckets") {
+ // Make sure that it works with two equally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6))
+ val buckets = Array(0.0, 5.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(3, 2)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithTwoBucketsAndNaN") {
+ // Make sure that it works with two equally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6, Double.NaN))
+ val buckets = Array(0.0, 5.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(3, 2)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithTwoUnevenBuckets") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6))
+ val buckets = Array(0.0, 5.0, 11.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(3, 2)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksMixedRangeWithTwoUnevenBuckets") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01))
+ val buckets = Array(0.0, 5.0, 11.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 3)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksMixedRangeWithFourUnevenBuckets") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0,
+ 200.0, 200.1))
+ val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 2, 1, 3)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksMixedRangeWithUnevenBucketsAndNaN") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0,
+ 200.0, 200.1, Double.NaN))
+ val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 2, 1, 3)
+ assert(histogramResults === expectedHistogramResults)
+ }
+ // Make sure this works with a NaN end bucket
+ test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRange") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0,
+ 200.0, 200.1, Double.NaN))
+ val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0, Double.NaN)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 2, 1, 2, 3)
+ assert(histogramResults === expectedHistogramResults)
+ }
+ // Make sure this works with a NaN end bucket and an inifity
+ test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRangeAndInfity") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0,
+ 200.0, 200.1, 1.0/0.0, -1.0/0.0, Double.NaN))
+ val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0, Double.NaN)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 2, 1, 2, 4)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksWithOutOfRangeWithInfiniteBuckets") {
+ // Verify that out of range works with two buckets
+ val rdd = sc.parallelize(Seq(10.01, -0.01, Double.NaN))
+ val buckets = Array(-1.0/0.0 , 0.0, 1.0/0.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(1, 1)
+ assert(histogramResults === expectedHistogramResults)
+ }
+ // Test the failure mode with an invalid bucket array
+ test("ThrowsExceptionOnInvalidBucketArray") {
+ val rdd = sc.parallelize(Seq(1.0))
+ // Empty array
+ intercept[IllegalArgumentException] {
+ val buckets = Array.empty[Double]
+ val result = rdd.histogram(buckets)
+ }
+ // Single element array
+ intercept[IllegalArgumentException] {
+ val buckets = Array(1.0)
+ val result = rdd.histogram(buckets)
+ }
+ }
+
+ // Test automatic histogram function
+ test("WorksWithoutBucketsBasic") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2, 3, 4))
+ val (histogramBuckets, histogramResults) = rdd.histogram(1)
+ val expectedHistogramResults = Array(4)
+ val expectedHistogramBuckets = Array(1.0, 4.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+ // Test automatic histogram function with a single element
+ test("WorksWithoutBucketsBasicSingleElement") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1))
+ val (histogramBuckets, histogramResults) = rdd.histogram(1)
+ val expectedHistogramResults = Array(1)
+ val expectedHistogramBuckets = Array(1.0, 1.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+ // Test automatic histogram function with a single element
+ test("WorksWithoutBucketsBasicNoRange") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 1, 1, 1))
+ val (histogramBuckets, histogramResults) = rdd.histogram(1)
+ val expectedHistogramResults = Array(4)
+ val expectedHistogramBuckets = Array(1.0, 1.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+
+ test("WorksWithoutBucketsBasicTwo") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2, 3, 4))
+ val (histogramBuckets, histogramResults) = rdd.histogram(2)
+ val expectedHistogramResults = Array(2, 2)
+ val expectedHistogramBuckets = Array(1.0, 2.5, 4.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+
+ test("WorksWithoutBucketsWithMoreRequestedThanElements") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2))
+ val (histogramBuckets, histogramResults) = rdd.histogram(10)
+ val expectedHistogramResults =
+ Array(1, 0, 0, 0, 0, 0, 0, 0, 0, 1)
+ val expectedHistogramBuckets =
+ Array(1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+
+ // Test the failure mode with an invalid RDD
+ test("ThrowsExceptionOnInvalidRDDs") {
+ // infinity
+ intercept[UnsupportedOperationException] {
+ val rdd = sc.parallelize(Seq(1, 1.0/0.0))
+ val result = rdd.histogram(1)
+ }
+ // NaN
+ intercept[UnsupportedOperationException] {
+ val rdd = sc.parallelize(Seq(1, Double.NaN))
+ val result = rdd.histogram(1)
+ }
+ // Empty
+ intercept[UnsupportedOperationException] {
+ val rdd: RDD[Double] = sc.parallelize(Seq())
+ val result = rdd.histogram(1)
+ }
+ }
+
+}