diff options
10 files changed, 203 insertions, 332 deletions
diff --git a/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala index 637492a975..5a5bd7fbbe 100644 --- a/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala @@ -17,21 +17,18 @@ package org.apache.spark.partial -import org.apache.commons.math3.distribution.NormalDistribution +import org.apache.commons.math3.distribution.{PascalDistribution, PoissonDistribution} /** * An ApproximateEvaluator for counts. - * - * TODO: There's currently a lot of shared code between this and GroupedCountEvaluator. It might - * be best to make this a special case of GroupedCountEvaluator with one group. */ private[spark] class CountEvaluator(totalOutputs: Int, confidence: Double) extends ApproximateEvaluator[Long, BoundedDouble] { - var outputsMerged = 0 - var sum: Long = 0 + private var outputsMerged = 0 + private var sum: Long = 0 - override def merge(outputId: Int, taskResult: Long) { + override def merge(outputId: Int, taskResult: Long): Unit = { outputsMerged += 1 sum += taskResult } @@ -39,18 +36,40 @@ private[spark] class CountEvaluator(totalOutputs: Int, confidence: Double) override def currentResult(): BoundedDouble = { if (outputsMerged == totalOutputs) { new BoundedDouble(sum, 1.0, sum, sum) - } else if (outputsMerged == 0) { - new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) + } else if (outputsMerged == 0 || sum == 0) { + new BoundedDouble(0, 0.0, 0.0, Double.PositiveInfinity) } else { val p = outputsMerged.toDouble / totalOutputs - val mean = (sum + 1 - p) / p - val variance = (sum + 1) * (1 - p) / (p * p) - val stdev = math.sqrt(variance) - val confFactor = new NormalDistribution(). - inverseCumulativeProbability(1 - (1 - confidence) / 2) - val low = mean - confFactor * stdev - val high = mean + confFactor * stdev - new BoundedDouble(mean, confidence, low, high) + CountEvaluator.bound(confidence, sum, p) } } } + +private[partial] object CountEvaluator { + + def bound(confidence: Double, sum: Long, p: Double): BoundedDouble = { + // Let the total count be N. A fraction p has been counted already, with sum 'sum', + // as if each element from the total data set had been seen with probability p. + val dist = + if (sum <= 10000) { + // The remaining count, k=N-sum, may be modeled as negative binomial (aka Pascal), + // where there have been 'sum' successes of probability p already. (There are several + // conventions, but this is the one followed by Commons Math3.) + new PascalDistribution(sum.toInt, p) + } else { + // For large 'sum' (certainly, > Int.MaxValue!), use a Poisson approximation, which has + // a different interpretation. "sum" elements have been observed having scanned a fraction + // p of the data. This suggests data is counted at a rate of sum / p across the whole data + // set. The total expected count from the rest is distributed as + // (1-p) Poisson(sum / p) = Poisson(sum*(1-p)/p) + new PoissonDistribution(sum * (1 - p) / p) + } + // Not quite symmetric; calculate interval straight from discrete distribution + val low = dist.inverseCumulativeProbability((1 - confidence) / 2) + val high = dist.inverseCumulativeProbability((1 + confidence) / 2) + // Add 'sum' to each because distribution is just of remaining count, not observed + new BoundedDouble(sum + dist.getNumericalMean, confidence, sum + low, sum + high) + } + + +} diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala index 5afce75680..d2b4187df5 100644 --- a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala @@ -17,15 +17,10 @@ package org.apache.spark.partial -import java.util.{HashMap => JHashMap} - -import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.mutable.HashMap import scala.reflect.ClassTag -import org.apache.commons.math3.distribution.NormalDistribution - import org.apache.spark.util.collection.OpenHashMap /** @@ -34,10 +29,10 @@ import org.apache.spark.util.collection.OpenHashMap private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, confidence: Double) extends ApproximateEvaluator[OpenHashMap[T, Long], Map[T, BoundedDouble]] { - var outputsMerged = 0 - var sums = new OpenHashMap[T, Long]() // Sum of counts for each key + private var outputsMerged = 0 + private val sums = new OpenHashMap[T, Long]() // Sum of counts for each key - override def merge(outputId: Int, taskResult: OpenHashMap[T, Long]) { + override def merge(outputId: Int, taskResult: OpenHashMap[T, Long]): Unit = { outputsMerged += 1 taskResult.foreach { case (key, value) => sums.changeValue(key, value, _ + value) @@ -46,27 +41,12 @@ private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, conf override def currentResult(): Map[T, BoundedDouble] = { if (outputsMerged == totalOutputs) { - val result = new JHashMap[T, BoundedDouble](sums.size) - sums.foreach { case (key, sum) => - result.put(key, new BoundedDouble(sum, 1.0, sum, sum)) - } - result.asScala + sums.map { case (key, sum) => (key, new BoundedDouble(sum, 1.0, sum, sum)) }.toMap } else if (outputsMerged == 0) { new HashMap[T, BoundedDouble] } else { val p = outputsMerged.toDouble / totalOutputs - val confFactor = new NormalDistribution(). - inverseCumulativeProbability(1 - (1 - confidence) / 2) - val result = new JHashMap[T, BoundedDouble](sums.size) - sums.foreach { case (key, sum) => - val mean = (sum + 1 - p) / p - val variance = (sum + 1) * (1 - p) / (p * p) - val stdev = math.sqrt(variance) - val low = mean - confFactor * stdev - val high = mean + confFactor * stdev - result.put(key, new BoundedDouble(mean, confidence, low, high)) - } - result.asScala + sums.map { case (key, sum) => (key, CountEvaluator.bound(confidence, sum, p)) }.toMap } } } diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala deleted file mode 100644 index a164040684..0000000000 --- a/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.partial - -import java.util.{HashMap => JHashMap} - -import scala.collection.JavaConverters._ -import scala.collection.Map -import scala.collection.mutable.HashMap - -import org.apache.spark.util.StatCounter - -/** - * An ApproximateEvaluator for means by key. Returns a map of key to confidence interval. - */ -private[spark] class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] { - - var outputsMerged = 0 - var sums = new JHashMap[T, StatCounter] // Sum of counts for each key - - override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) { - outputsMerged += 1 - val iter = taskResult.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val old = sums.get(entry.getKey) - if (old != null) { - old.merge(entry.getValue) - } else { - sums.put(entry.getKey, entry.getValue) - } - } - } - - override def currentResult(): Map[T, BoundedDouble] = { - if (outputsMerged == totalOutputs) { - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val mean = entry.getValue.mean - result.put(entry.getKey, new BoundedDouble(mean, 1.0, mean, mean)) - } - result.asScala - } else if (outputsMerged == 0) { - new HashMap[T, BoundedDouble] - } else { - val studentTCacher = new StudentTCacher(confidence) - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val counter = entry.getValue - val mean = counter.mean - val stdev = math.sqrt(counter.sampleVariance / counter.count) - val confFactor = studentTCacher.get(counter.count) - val low = mean - confFactor * stdev - val high = mean + confFactor * stdev - result.put(entry.getKey, new BoundedDouble(mean, confidence, low, high)) - } - result.asScala - } - } -} diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala deleted file mode 100644 index 54a1beab35..0000000000 --- a/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.partial - -import java.util.{HashMap => JHashMap} - -import scala.collection.JavaConverters._ -import scala.collection.Map -import scala.collection.mutable.HashMap - -import org.apache.spark.util.StatCounter - -/** - * An ApproximateEvaluator for sums by key. Returns a map of key to confidence interval. - */ -private[spark] class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] { - - var outputsMerged = 0 - var sums = new JHashMap[T, StatCounter] // Sum of counts for each key - - override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) { - outputsMerged += 1 - val iter = taskResult.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val old = sums.get(entry.getKey) - if (old != null) { - old.merge(entry.getValue) - } else { - sums.put(entry.getKey, entry.getValue) - } - } - } - - override def currentResult(): Map[T, BoundedDouble] = { - if (outputsMerged == totalOutputs) { - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val sum = entry.getValue.sum - result.put(entry.getKey, new BoundedDouble(sum, 1.0, sum, sum)) - } - result.asScala - } else if (outputsMerged == 0) { - new HashMap[T, BoundedDouble] - } else { - val p = outputsMerged.toDouble / totalOutputs - val studentTCacher = new StudentTCacher(confidence) - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val counter = entry.getValue - val meanEstimate = counter.mean - val meanVar = counter.sampleVariance / counter.count - val countEstimate = (counter.count + 1 - p) / p - val countVar = (counter.count + 1) * (1 - p) / (p * p) - val sumEstimate = meanEstimate * countEstimate - val sumVar = (meanEstimate * meanEstimate * countVar) + - (countEstimate * countEstimate * meanVar) + - (meanVar * countVar) - val sumStdev = math.sqrt(sumVar) - val confFactor = studentTCacher.get(counter.count) - val low = sumEstimate - confFactor * sumStdev - val high = sumEstimate + confFactor * sumStdev - result.put(entry.getKey, new BoundedDouble(sumEstimate, confidence, low, high)) - } - result.asScala - } - } -} diff --git a/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala index 787a21a61f..3fb2d30a80 100644 --- a/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala @@ -27,10 +27,10 @@ import org.apache.spark.util.StatCounter private[spark] class MeanEvaluator(totalOutputs: Int, confidence: Double) extends ApproximateEvaluator[StatCounter, BoundedDouble] { - var outputsMerged = 0 - var counter = new StatCounter + private var outputsMerged = 0 + private val counter = new StatCounter() - override def merge(outputId: Int, taskResult: StatCounter) { + override def merge(outputId: Int, taskResult: StatCounter): Unit = { outputsMerged += 1 counter.merge(taskResult) } @@ -38,19 +38,24 @@ private[spark] class MeanEvaluator(totalOutputs: Int, confidence: Double) override def currentResult(): BoundedDouble = { if (outputsMerged == totalOutputs) { new BoundedDouble(counter.mean, 1.0, counter.mean, counter.mean) - } else if (outputsMerged == 0) { + } else if (outputsMerged == 0 || counter.count == 0) { new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) + } else if (counter.count == 1) { + new BoundedDouble(counter.mean, confidence, Double.NegativeInfinity, Double.PositiveInfinity) } else { val mean = counter.mean val stdev = math.sqrt(counter.sampleVariance / counter.count) - val confFactor = { - if (counter.count > 100) { - new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2) + val confFactor = if (counter.count > 100) { + // For large n, the normal distribution is a good approximation to t-distribution + new NormalDistribution().inverseCumulativeProbability((1 + confidence) / 2) } else { + // t-distribution describes distribution of actual population mean + // note that if this goes to 0, TDistribution will throw an exception. + // Hence special casing 1 above. val degreesOfFreedom = (counter.count - 1).toInt - new TDistribution(degreesOfFreedom).inverseCumulativeProbability(1 - (1 - confidence) / 2) + new TDistribution(degreesOfFreedom).inverseCumulativeProbability((1 + confidence) / 2) } - } + // Symmetric, so confidence interval is symmetric about mean of distribution val low = mean - confFactor * stdev val high = mean + confFactor * stdev new BoundedDouble(mean, confidence, low, high) diff --git a/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala b/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala deleted file mode 100644 index 55acb9ca64..0000000000 --- a/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.partial - -import org.apache.commons.math3.distribution.{NormalDistribution, TDistribution} - -/** - * A utility class for caching Student's T distribution values for a given confidence level - * and various sample sizes. This is used by the MeanEvaluator to efficiently calculate - * confidence intervals for many keys. - */ -private[spark] class StudentTCacher(confidence: Double) { - - val NORMAL_APPROX_SAMPLE_SIZE = 100 // For samples bigger than this, use Gaussian approximation - - val normalApprox = new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2) - val cache = Array.fill[Double](NORMAL_APPROX_SAMPLE_SIZE)(-1.0) - - def get(sampleSize: Long): Double = { - if (sampleSize >= NORMAL_APPROX_SAMPLE_SIZE) { - normalApprox - } else { - val size = sampleSize.toInt - if (cache(size) < 0) { - val tDist = new TDistribution(size - 1) - cache(size) = tDist.inverseCumulativeProbability(1 - (1 - confidence) / 2) - } - cache(size) - } - } -} diff --git a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala index 5fe3358316..1988052b73 100644 --- a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala @@ -30,10 +30,10 @@ private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double) extends ApproximateEvaluator[StatCounter, BoundedDouble] { // modified in merge - var outputsMerged = 0 - val counter = new StatCounter + private var outputsMerged = 0 + private val counter = new StatCounter() - override def merge(outputId: Int, taskResult: StatCounter) { + override def merge(outputId: Int, taskResult: StatCounter): Unit = { outputsMerged += 1 counter.merge(taskResult) } @@ -45,34 +45,45 @@ private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double) new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) } else { val p = outputsMerged.toDouble / totalOutputs + // Expected value of unobserved is presumed equal to that of the observed data val meanEstimate = counter.mean - val countEstimate = (counter.count + 1 - p) / p + // Expected size of rest of the data is proportional + val countEstimate = counter.count * (1 - p) / p + // Expected sum is simply their product val sumEstimate = meanEstimate * countEstimate + // Variance of unobserved data is presumed equal to that of the observed data val meanVar = counter.sampleVariance / counter.count - // branch at this point because counter.count == 1 implies counter.sampleVariance == Nan + // branch at this point because count == 1 implies counter.sampleVariance == Nan // and we don't want to ever return a bound of NaN if (meanVar.isNaN || counter.count == 1) { - new BoundedDouble(sumEstimate, confidence, Double.NegativeInfinity, Double.PositiveInfinity) + // add sum because estimate is of unobserved data sum + new BoundedDouble( + counter.sum + sumEstimate, confidence, Double.NegativeInfinity, Double.PositiveInfinity) } else { - val countVar = (counter.count + 1) * (1 - p) / (p * p) + // See CountEvaluator. Variance of population count here follows from negative binomial + val countVar = counter.count * (1 - p) / (p * p) + // Var(Sum) = Var(Mean*Count) = + // [E(Mean)]^2 * Var(Count) + [E(Count)]^2 * Var(Mean) + Var(Mean) * Var(Count) val sumVar = (meanEstimate * meanEstimate * countVar) + (countEstimate * countEstimate * meanVar) + (meanVar * countVar) val sumStdev = math.sqrt(sumVar) val confFactor = if (counter.count > 100) { - new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2) + new NormalDistribution().inverseCumulativeProbability((1 + confidence) / 2) } else { // note that if this goes to 0, TDistribution will throw an exception. // Hence special casing 1 above. val degreesOfFreedom = (counter.count - 1).toInt - new TDistribution(degreesOfFreedom).inverseCumulativeProbability(1 - (1 - confidence) / 2) + new TDistribution(degreesOfFreedom).inverseCumulativeProbability((1 + confidence) / 2) } - + // Symmetric, so confidence interval is symmetric about mean of distribution val low = sumEstimate - confFactor * sumStdev val high = sumEstimate + confFactor * sumStdev - new BoundedDouble(sumEstimate, confidence, low, high) + // add sum because estimate is of unobserved data sum + new BoundedDouble( + counter.sum + sumEstimate, confidence, counter.sum + low, counter.sum + high) } } } diff --git a/core/src/test/scala/org/apache/spark/partial/CountEvaluatorSuite.scala b/core/src/test/scala/org/apache/spark/partial/CountEvaluatorSuite.scala new file mode 100644 index 0000000000..da3256bd88 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/partial/CountEvaluatorSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.partial + +import org.apache.spark.SparkFunSuite + +class CountEvaluatorSuite extends SparkFunSuite { + + test("test count 0") { + val evaluator = new CountEvaluator(10, 0.95) + assert(new BoundedDouble(0.0, 0.0, 0.0, Double.PositiveInfinity) == evaluator.currentResult()) + evaluator.merge(1, 0) + assert(new BoundedDouble(0.0, 0.0, 0.0, Double.PositiveInfinity) == evaluator.currentResult()) + } + + test("test count >= 1") { + val evaluator = new CountEvaluator(10, 0.95) + evaluator.merge(1, 1) + assert(new BoundedDouble(10.0, 0.95, 1.0, 36.0) == evaluator.currentResult()) + evaluator.merge(1, 3) + assert(new BoundedDouble(20.0, 0.95, 7.0, 41.0) == evaluator.currentResult()) + evaluator.merge(1, 8) + assert(new BoundedDouble(40.0, 0.95, 24.0, 61.0) == evaluator.currentResult()) + (4 to 10).foreach(_ => evaluator.merge(1, 10)) + assert(new BoundedDouble(82.0, 1.0, 82.0, 82.0) == evaluator.currentResult()) + } + +} diff --git a/core/src/test/scala/org/apache/spark/partial/MeanEvaluatorSuite.scala b/core/src/test/scala/org/apache/spark/partial/MeanEvaluatorSuite.scala new file mode 100644 index 0000000000..eaa1262b41 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/partial/MeanEvaluatorSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.partial + +import org.apache.spark.SparkFunSuite +import org.apache.spark.util.StatCounter + +class MeanEvaluatorSuite extends SparkFunSuite { + + test("test count 0") { + val evaluator = new MeanEvaluator(10, 0.95) + assert(new BoundedDouble(0.0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) == + evaluator.currentResult()) + evaluator.merge(1, new StatCounter()) + assert(new BoundedDouble(0.0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) == + evaluator.currentResult()) + evaluator.merge(1, new StatCounter(Seq(0.0))) + assert(new BoundedDouble(0.0, 0.95, Double.NegativeInfinity, Double.PositiveInfinity) == + evaluator.currentResult()) + } + + test("test count 1") { + val evaluator = new MeanEvaluator(10, 0.95) + evaluator.merge(1, new StatCounter(Seq(1.0))) + assert(new BoundedDouble(1.0, 0.95, Double.NegativeInfinity, Double.PositiveInfinity) == + evaluator.currentResult()) + } + + test("test count > 1") { + val evaluator = new MeanEvaluator(10, 0.95) + evaluator.merge(1, new StatCounter(Seq(1.0))) + evaluator.merge(1, new StatCounter(Seq(3.0))) + assert(new BoundedDouble(2.0, 0.95, -10.706204736174746, 14.706204736174746) == + evaluator.currentResult()) + evaluator.merge(1, new StatCounter(Seq(8.0))) + assert(new BoundedDouble(4.0, 0.95, -4.9566858949231225, 12.956685894923123) == + evaluator.currentResult()) + (4 to 10).foreach(_ => evaluator.merge(1, new StatCounter(Seq(9.0)))) + assert(new BoundedDouble(7.5, 1.0, 7.5, 7.5) == evaluator.currentResult()) + } + +} diff --git a/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala b/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala index a79f5b4d74..e212db7362 100644 --- a/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala @@ -17,61 +17,34 @@ package org.apache.spark.partial -import org.apache.spark._ +import org.apache.spark.SparkFunSuite import org.apache.spark.util.StatCounter -class SumEvaluatorSuite extends SparkFunSuite with SharedSparkContext { +class SumEvaluatorSuite extends SparkFunSuite { test("correct handling of count 1") { + // sanity check: + assert(new BoundedDouble(2.0, 0.95, 1.1, 1.2) == new BoundedDouble(2.0, 0.95, 1.1, 1.2)) - // setup - val counter = new StatCounter(List(2.0)) // count of 10 because it's larger than 1, // and 0.95 because that's the default val evaluator = new SumEvaluator(10, 0.95) // arbitrarily assign id 1 - evaluator.merge(1, counter) - - // execute - val res = evaluator.currentResult() - // 38.0 - 7.1E-15 because that's how the maths shakes out - val targetMean = 38.0 - 7.1E-15 - - // Sanity check that equality works on BoundedDouble - assert(new BoundedDouble(2.0, 0.95, 1.1, 1.2) == new BoundedDouble(2.0, 0.95, 1.1, 1.2)) - - // actual test - assert(res == - new BoundedDouble(targetMean, 0.950, Double.NegativeInfinity, Double.PositiveInfinity)) + evaluator.merge(1, new StatCounter(Seq(2.0))) + assert(new BoundedDouble(20.0, 0.95, Double.NegativeInfinity, Double.PositiveInfinity) == + evaluator.currentResult()) } test("correct handling of count 0") { - - // setup - val counter = new StatCounter(List()) - // count of 10 because it's larger than 0, - // and 0.95 because that's the default val evaluator = new SumEvaluator(10, 0.95) - // arbitrarily assign id 1 - evaluator.merge(1, counter) - - // execute - val res = evaluator.currentResult() - // assert - assert(res == new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)) + evaluator.merge(1, new StatCounter()) + assert(new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) == + evaluator.currentResult()) } test("correct handling of NaN") { - - // setup - val counter = new StatCounter(List(1, Double.NaN, 2)) - // count of 10 because it's larger than 0, - // and 0.95 because that's the default val evaluator = new SumEvaluator(10, 0.95) - // arbitrarily assign id 1 - evaluator.merge(1, counter) - - // execute + evaluator.merge(1, new StatCounter(Seq(1, Double.NaN, 2))) val res = evaluator.currentResult() // assert - note semantics of == in face of NaN assert(res.mean.isNaN) @@ -81,27 +54,24 @@ class SumEvaluatorSuite extends SparkFunSuite with SharedSparkContext { } test("correct handling of > 1 values") { - - // setup - val counter = new StatCounter(List(1, 3, 2)) - // count of 10 because it's larger than 0, - // and 0.95 because that's the default val evaluator = new SumEvaluator(10, 0.95) - // arbitrarily assign id 1 - evaluator.merge(1, counter) - - // execute + evaluator.merge(1, new StatCounter(Seq(1.0, 3.0, 2.0))) val res = evaluator.currentResult() + assert(new BoundedDouble(60.0, 0.95, -101.7362525347778, 221.7362525347778) == + evaluator.currentResult()) + } - // These vals because that's how the maths shakes out - val targetMean = 78.0 - val targetLow = -117.617 + 2.732357258139473E-5 - val targetHigh = 273.617 - 2.7323572624027292E-5 - val target = new BoundedDouble(targetMean, 0.95, targetLow, targetHigh) - - - // check that values are within expected tolerance of expectation - assert(res == target) + test("test count > 1") { + val evaluator = new SumEvaluator(10, 0.95) + evaluator.merge(1, new StatCounter().merge(1.0)) + evaluator.merge(1, new StatCounter().merge(3.0)) + assert(new BoundedDouble(20.0, 0.95, -186.4513905077019, 226.4513905077019) == + evaluator.currentResult()) + evaluator.merge(1, new StatCounter().merge(8.0)) + assert(new BoundedDouble(40.0, 0.95, -72.75723361226733, 152.75723361226733) == + evaluator.currentResult()) + (4 to 10).foreach(_ => evaluator.merge(1, new StatCounter().merge(9.0))) + assert(new BoundedDouble(75.0, 1.0, 75.0, 75.0) == evaluator.currentResult()) } } |