aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Tustin <mtustin@handybook.com>2016-04-03 17:42:33 -0700
committerSean Owen <sowen@cloudera.com>2016-04-03 17:42:33 -0700
commit9023015f059327b3ce4a7eaf71e57ac77b84ad7b (patch)
tree6f1b7d71a36f8acc573a9e8cce31ddd05efa50fb
parentc238cd07448f94bbceb661daad90b6a6d597a846 (diff)
downloadspark-9023015f059327b3ce4a7eaf71e57ac77b84ad7b.tar.gz
spark-9023015f059327b3ce4a7eaf71e57ac77b84ad7b.tar.bz2
spark-9023015f059327b3ce4a7eaf71e57ac77b84ad7b.zip
[SPARK-14163][CORE] SumEvaluator and countApprox cannot reliably handle RDDs of size 1
## What changes were proposed in this pull request? This special cases 0 and 1 counts to avoid passing 0 degrees of freedom. ## How was this patch tested? Tests run successfully. New test added. ## Note: This recreates #11982 which was closed to due to non-updated diff. rxin srowen Commented there. This also adds tests, reworks the code to perform the special casing (based on srowen's comments), and adds equality machinery for BoundedDouble, as well as changing how it is transformed to string. Author: Marcin Tustin <mtustin@handybook.com> Author: Marcin Tustin <mtustin@handy.com> Closes #12016 from mtustin-handy/SPARK-14163.
-rw-r--r--core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala18
-rw-r--r--core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala36
-rw-r--r--core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala107
3 files changed, 148 insertions, 13 deletions
diff --git a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala
index 48b9434153..d06b2c67d2 100644
--- a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala
+++ b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala
@@ -21,5 +21,23 @@ package org.apache.spark.partial
* A Double value with error bars and associated confidence.
*/
class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, val high: Double) {
+
override def toString(): String = "[%.3f, %.3f]".format(low, high)
+
+ override def hashCode: Int =
+ this.mean.hashCode ^ this.confidence.hashCode ^ this.low.hashCode ^ this.high.hashCode
+
+ /**
+ * Note that consistent with Double, any NaN value will make equality false
+ */
+ override def equals(that: Any): Boolean =
+ that match {
+ case that: BoundedDouble => {
+ this.mean == that.mean &&
+ this.confidence == that.confidence &&
+ this.low == that.low &&
+ this.high == that.high
+ }
+ case _ => false
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala
index 44295e5a1a..5fe3358316 100644
--- a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala
+++ b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala
@@ -29,8 +29,9 @@ import org.apache.spark.util.StatCounter
private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[StatCounter, BoundedDouble] {
+ // modified in merge
var outputsMerged = 0
- var counter = new StatCounter
+ val counter = new StatCounter
override def merge(outputId: Int, taskResult: StatCounter) {
outputsMerged += 1
@@ -40,30 +41,39 @@ private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double)
override def currentResult(): BoundedDouble = {
if (outputsMerged == totalOutputs) {
new BoundedDouble(counter.sum, 1.0, counter.sum, counter.sum)
- } else if (outputsMerged == 0) {
+ } else if (outputsMerged == 0 || counter.count == 0) {
new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)
} else {
val p = outputsMerged.toDouble / totalOutputs
val meanEstimate = counter.mean
- val meanVar = counter.sampleVariance / counter.count
val countEstimate = (counter.count + 1 - p) / p
- val countVar = (counter.count + 1) * (1 - p) / (p * p)
val sumEstimate = meanEstimate * countEstimate
- val sumVar = (meanEstimate * meanEstimate * countVar) +
- (countEstimate * countEstimate * meanVar) +
- (meanVar * countVar)
- val sumStdev = math.sqrt(sumVar)
- val confFactor = {
- if (counter.count > 100) {
+
+ val meanVar = counter.sampleVariance / counter.count
+
+ // branch at this point because counter.count == 1 implies counter.sampleVariance == Nan
+ // and we don't want to ever return a bound of NaN
+ if (meanVar.isNaN || counter.count == 1) {
+ new BoundedDouble(sumEstimate, confidence, Double.NegativeInfinity, Double.PositiveInfinity)
+ } else {
+ val countVar = (counter.count + 1) * (1 - p) / (p * p)
+ val sumVar = (meanEstimate * meanEstimate * countVar) +
+ (countEstimate * countEstimate * meanVar) +
+ (meanVar * countVar)
+ val sumStdev = math.sqrt(sumVar)
+ val confFactor = if (counter.count > 100) {
new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2)
} else {
+ // note that if this goes to 0, TDistribution will throw an exception.
+ // Hence special casing 1 above.
val degreesOfFreedom = (counter.count - 1).toInt
new TDistribution(degreesOfFreedom).inverseCumulativeProbability(1 - (1 - confidence) / 2)
}
+
+ val low = sumEstimate - confFactor * sumStdev
+ val high = sumEstimate + confFactor * sumStdev
+ new BoundedDouble(sumEstimate, confidence, low, high)
}
- val low = sumEstimate - confFactor * sumStdev
- val high = sumEstimate + confFactor * sumStdev
- new BoundedDouble(sumEstimate, confidence, low, high)
}
}
}
diff --git a/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala b/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala
new file mode 100644
index 0000000000..a79f5b4d74
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.partial
+
+import org.apache.spark._
+import org.apache.spark.util.StatCounter
+
+class SumEvaluatorSuite extends SparkFunSuite with SharedSparkContext {
+
+ test("correct handling of count 1") {
+
+ // setup
+ val counter = new StatCounter(List(2.0))
+ // count of 10 because it's larger than 1,
+ // and 0.95 because that's the default
+ val evaluator = new SumEvaluator(10, 0.95)
+ // arbitrarily assign id 1
+ evaluator.merge(1, counter)
+
+ // execute
+ val res = evaluator.currentResult()
+ // 38.0 - 7.1E-15 because that's how the maths shakes out
+ val targetMean = 38.0 - 7.1E-15
+
+ // Sanity check that equality works on BoundedDouble
+ assert(new BoundedDouble(2.0, 0.95, 1.1, 1.2) == new BoundedDouble(2.0, 0.95, 1.1, 1.2))
+
+ // actual test
+ assert(res ==
+ new BoundedDouble(targetMean, 0.950, Double.NegativeInfinity, Double.PositiveInfinity))
+ }
+
+ test("correct handling of count 0") {
+
+ // setup
+ val counter = new StatCounter(List())
+ // count of 10 because it's larger than 0,
+ // and 0.95 because that's the default
+ val evaluator = new SumEvaluator(10, 0.95)
+ // arbitrarily assign id 1
+ evaluator.merge(1, counter)
+
+ // execute
+ val res = evaluator.currentResult()
+ // assert
+ assert(res == new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity))
+ }
+
+ test("correct handling of NaN") {
+
+ // setup
+ val counter = new StatCounter(List(1, Double.NaN, 2))
+ // count of 10 because it's larger than 0,
+ // and 0.95 because that's the default
+ val evaluator = new SumEvaluator(10, 0.95)
+ // arbitrarily assign id 1
+ evaluator.merge(1, counter)
+
+ // execute
+ val res = evaluator.currentResult()
+ // assert - note semantics of == in face of NaN
+ assert(res.mean.isNaN)
+ assert(res.confidence == 0.95)
+ assert(res.low == Double.NegativeInfinity)
+ assert(res.high == Double.PositiveInfinity)
+ }
+
+ test("correct handling of > 1 values") {
+
+ // setup
+ val counter = new StatCounter(List(1, 3, 2))
+ // count of 10 because it's larger than 0,
+ // and 0.95 because that's the default
+ val evaluator = new SumEvaluator(10, 0.95)
+ // arbitrarily assign id 1
+ evaluator.merge(1, counter)
+
+ // execute
+ val res = evaluator.currentResult()
+
+ // These vals because that's how the maths shakes out
+ val targetMean = 78.0
+ val targetLow = -117.617 + 2.732357258139473E-5
+ val targetHigh = 273.617 - 2.7323572624027292E-5
+ val target = new BoundedDouble(targetMean, 0.95, targetLow, targetHigh)
+
+
+ // check that values are within expected tolerance of expectation
+ assert(res == target)
+ }
+
+}