 core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala     |  18
 core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala      |  36
 core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala | 107
 3 files changed, 148 insertions(+), 13 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala
index 48b9434153..d06b2c67d2 100644
--- a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala
+++ b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala
@@ -21,5 +21,23 @@ package org.apache.spark.partial
* A Double value with error bars and associated confidence.
*/
class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, val high: Double) {
+
override def toString(): String = "[%.3f, %.3f]".format(low, high)
+
+ override def hashCode: Int =
+ this.mean.hashCode ^ this.confidence.hashCode ^ this.low.hashCode ^ this.high.hashCode
+
+ /**
+ * Note that, consistent with Double, any NaN value will make equality false.
+ */
+ override def equals(that: Any): Boolean =
+ that match {
+ case that: BoundedDouble => {
+ this.mean == that.mean &&
+ this.confidence == that.confidence &&
+ this.low == that.low &&
+ this.high == that.high
+ }
+ case _ => false
+ }
}
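
[reviewer note, not part of the patch] The equals/hashCode pair added above keeps the usual contract (equal objects hash equally), and, as the comment says, field-wise Double comparison means any NaN field defeats equality. A minimal sketch of the resulting semantics, assuming BoundedDouble is on the classpath:

    // NaN == NaN is false for Double, so a NaN-bearing instance is not
    // even equal to itself
    val withNaN = new BoundedDouble(Double.NaN, 0.95, 1.0, 2.0)
    assert(withNaN != new BoundedDouble(Double.NaN, 0.95, 1.0, 2.0))
    assert(withNaN != withNaN)
    // NaN-free instances compare field by field as expected
    assert(new BoundedDouble(2.0, 0.95, 1.1, 1.2) == new BoundedDouble(2.0, 0.95, 1.1, 1.2))
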
diff --git a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala
index 44295e5a1a..5fe3358316 100644
--- a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala
+++ b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala
@@ -29,8 +29,9 @@ import org.apache.spark.util.StatCounter
private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[StatCounter, BoundedDouble] {
+ // modified in merge
var outputsMerged = 0
- var counter = new StatCounter
+ val counter = new StatCounter
override def merge(outputId: Int, taskResult: StatCounter) {
outputsMerged += 1
@@ -40,30 +41,39 @@ private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double)
override def currentResult(): BoundedDouble = {
if (outputsMerged == totalOutputs) {
new BoundedDouble(counter.sum, 1.0, counter.sum, counter.sum)
- } else if (outputsMerged == 0) {
+ } else if (outputsMerged == 0 || counter.count == 0) {
new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)
} else {
val p = outputsMerged.toDouble / totalOutputs
val meanEstimate = counter.mean
- val meanVar = counter.sampleVariance / counter.count
val countEstimate = (counter.count + 1 - p) / p
- val countVar = (counter.count + 1) * (1 - p) / (p * p)
val sumEstimate = meanEstimate * countEstimate
- val sumVar = (meanEstimate * meanEstimate * countVar) +
- (countEstimate * countEstimate * meanVar) +
- (meanVar * countVar)
- val sumStdev = math.sqrt(sumVar)
- val confFactor = {
- if (counter.count > 100) {
+
+ val meanVar = counter.sampleVariance / counter.count
+
+ // branch at this point because counter.count == 1 makes counter.sampleVariance NaN,
+ // and we never want to return a bound of NaN
+ if (meanVar.isNaN || counter.count == 1) {
+ new BoundedDouble(sumEstimate, confidence, Double.NegativeInfinity, Double.PositiveInfinity)
+ } else {
+ val countVar = (counter.count + 1) * (1 - p) / (p * p)
+ val sumVar = (meanEstimate * meanEstimate * countVar) +
+ (countEstimate * countEstimate * meanVar) +
+ (meanVar * countVar)
+ val sumStdev = math.sqrt(sumVar)
+ val confFactor = if (counter.count > 100) {
new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2)
} else {
+ // note that if degreesOfFreedom reaches 0, TDistribution will throw an exception;
+ // hence the special-casing of count == 1 above.
val degreesOfFreedom = (counter.count - 1).toInt
new TDistribution(degreesOfFreedom).inverseCumulativeProbability(1 - (1 - confidence) / 2)
}
+
+ val low = sumEstimate - confFactor * sumStdev
+ val high = sumEstimate + confFactor * sumStdev
+ new BoundedDouble(sumEstimate, confidence, low, high)
}
- val low = sumEstimate - confFactor * sumStdev
- val high = sumEstimate + confFactor * sumStdev
- new BoundedDouble(sumEstimate, confidence, low, high)
}
}
}
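
[reviewer note, not part of the patch] The estimator in currentResult() extrapolates from the fraction p = outputsMerged / totalOutputs of partitions seen: the total count is estimated as (count + 1 - p) / p, the sum as meanEstimate * countEstimate, and the variance of that product uses the standard formula for a product of independent estimates, Var(XY) ~ E[X]^2 Var(Y) + E[Y]^2 Var(X) + Var(X) Var(Y), which is exactly the sumVar expression above. The confidence factor comes from a normal approximation when count > 100 and from Student's t otherwise, which is why count == 1 (zero degrees of freedom) must be special-cased. A self-contained sketch re-deriving the numbers used in the "> 1 values" test below, with commons-math3 assumed on the classpath:

    import org.apache.commons.math3.distribution.TDistribution

    // counter over List(1.0, 3.0, 2.0), 1 of 10 outputs merged, confidence 0.95
    val p = 1.0 / 10
    val count = 3
    val mean = 2.0
    val sampleVariance = 1.0                        // ((1-2)^2 + (3-2)^2 + (2-2)^2) / (3 - 1)
    val meanVar = sampleVariance / count            // variance of the mean estimate
    val countEstimate = (count + 1 - p) / p         // 39.0
    val sumEstimate = mean * countEstimate          // 78.0, the targetMean in the suite
    val countVar = (count + 1) * (1 - p) / (p * p)  // 360.0
    val sumVar = mean * mean * countVar +
      countEstimate * countEstimate * meanVar +
      meanVar * countVar                            // 2067.0
    val sumStdev = math.sqrt(sumVar)
    // count <= 100, so Student's t with count - 1 = 2 degrees of freedom
    val confFactor = new TDistribution(count - 1).inverseCumulativeProbability(0.975)
    println(sumEstimate - confFactor * sumStdev)    // ~ -117.617, targetLow in the suite
    println(sumEstimate + confFactor * sumStdev)    // ~  273.617, targetHigh in the suite
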
diff --git a/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala b/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala
new file mode 100644
index 0000000000..a79f5b4d74
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.partial
+
+import org.apache.spark._
+import org.apache.spark.util.StatCounter
+
+class SumEvaluatorSuite extends SparkFunSuite with SharedSparkContext {
+
+ test("correct handling of count 1") {
+
+ // setup
+ val counter = new StatCounter(List(2.0))
+ // totalOutputs of 10 because it's larger than 1,
+ // and confidence of 0.95 because that's the default
+ val evaluator = new SumEvaluator(10, 0.95)
+ // arbitrarily assign id 1
+ evaluator.merge(1, counter)
+
+ // execute
+ val res = evaluator.currentResult()
+ // 38.0 - 7.1E-15 because that's how the maths shakes out
+ val targetMean = 38.0 - 7.1E-15
+
+ // Sanity check that equality works on BoundedDouble
+ assert(new BoundedDouble(2.0, 0.95, 1.1, 1.2) == new BoundedDouble(2.0, 0.95, 1.1, 1.2))
+
+ // actual test
+ assert(res ==
+ new BoundedDouble(targetMean, 0.950, Double.NegativeInfinity, Double.PositiveInfinity))
+ }
+
+ test("correct handling of count 0") {
+
+ // setup
+ val counter = new StatCounter(List())
+ // totalOutputs of 10 because it's larger than 0,
+ // and confidence of 0.95 because that's the default
+ val evaluator = new SumEvaluator(10, 0.95)
+ // arbitrarily assign id 1
+ evaluator.merge(1, counter)
+
+ // execute
+ val res = evaluator.currentResult()
+ // assert
+ assert(res == new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity))
+ }
+
+ test("correct handling of NaN") {
+
+ // setup
+ val counter = new StatCounter(List(1, Double.NaN, 2))
+ // totalOutputs of 10 because it's larger than 0,
+ // and confidence of 0.95 because that's the default
+ val evaluator = new SumEvaluator(10, 0.95)
+ // arbitrarily assign id 1
+ evaluator.merge(1, counter)
+
+ // execute
+ val res = evaluator.currentResult()
+ // assert - note the semantics of == in the face of NaN: check fields individually
+ assert(res.mean.isNaN)
+ assert(res.confidence == 0.95)
+ assert(res.low == Double.NegativeInfinity)
+ assert(res.high == Double.PositiveInfinity)
+ }
+
+ test("correct handling of > 1 values") {
+
+ // setup
+ val counter = new StatCounter(List(1, 3, 2))
+ // totalOutputs of 10 because it's larger than 0,
+ // and confidence of 0.95 because that's the default
+ val evaluator = new SumEvaluator(10, 0.95)
+ // arbitrarily assign id 1
+ evaluator.merge(1, counter)
+
+ // execute
+ val res = evaluator.currentResult()
+
+ // These vals because that's how the maths shakes out
+ val targetMean = 78.0
+ val targetLow = -117.617 + 2.732357258139473E-5
+ val targetHigh = 273.617 - 2.7323572624027292E-5
+ val target = new BoundedDouble(targetMean, 0.95, targetLow, targetHigh)
+
+ // check exact equality against the expected result
+ assert(res == target)
+ }
+
+}
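
[reviewer note, not part of the patch] In practice this evaluator sits behind RDD.sumApprox on double-valued RDDs, which returns a PartialResult whose value is the BoundedDouble the tests above compare against. A minimal usage sketch; the local master and the 500 ms timeout are arbitrary choices for illustration:

    import org.apache.spark.{SparkConf, SparkContext}

    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sumApprox-demo"))
    val rdd = sc.parallelize(1 to 1000000).map(_.toDouble)
    // Return whatever bound is available after 500 ms at the default 0.95 confidence;
    // the exact sum here is 5.000005E11, so the bounds should bracket that.
    val partial = rdd.sumApprox(timeout = 500)
    println(partial.initialValue)
    sc.stop()
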