aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala
diff options
context:
space:
mode:
Diffstat (limited to 'core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala')
-rw-r--r--core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala32
1 files changed, 13 insertions, 19 deletions
diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala
index 40b70baabc..8bb78123e3 100644
--- a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala
+++ b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala
@@ -22,36 +22,33 @@ import java.util.{HashMap => JHashMap}
import scala.collection.JavaConversions.mapAsScalaMap
import scala.collection.Map
import scala.collection.mutable.HashMap
+import scala.reflect.ClassTag
import cern.jet.stat.Probability
-import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
+
+import org.apache.spark.util.collection.OpenHashMap
/**
* An ApproximateEvaluator for counts by key. Returns a map of key to confidence interval.
*/
-private[spark] class GroupedCountEvaluator[T](totalOutputs: Int, confidence: Double)
- extends ApproximateEvaluator[OLMap[T], Map[T, BoundedDouble]] {
+private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, confidence: Double)
+ extends ApproximateEvaluator[OpenHashMap[T,Long], Map[T, BoundedDouble]] {
var outputsMerged = 0
- var sums = new OLMap[T] // Sum of counts for each key
+ var sums = new OpenHashMap[T,Long]() // Sum of counts for each key
- override def merge(outputId: Int, taskResult: OLMap[T]) {
+ override def merge(outputId: Int, taskResult: OpenHashMap[T,Long]) {
outputsMerged += 1
- val iter = taskResult.object2LongEntrySet.fastIterator()
- while (iter.hasNext) {
- val entry = iter.next()
- sums.put(entry.getKey, sums.getLong(entry.getKey) + entry.getLongValue)
+ taskResult.foreach { case (key, value) =>
+ sums.changeValue(key, value, _ + value)
}
}
override def currentResult(): Map[T, BoundedDouble] = {
if (outputsMerged == totalOutputs) {
val result = new JHashMap[T, BoundedDouble](sums.size)
- val iter = sums.object2LongEntrySet.fastIterator()
- while (iter.hasNext) {
- val entry = iter.next()
- val sum = entry.getLongValue()
- result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum)
+ sums.foreach { case (key, sum) =>
+ result(key) = new BoundedDouble(sum, 1.0, sum, sum)
}
result
} else if (outputsMerged == 0) {
@@ -60,16 +57,13 @@ private[spark] class GroupedCountEvaluator[T](totalOutputs: Int, confidence: Dou
val p = outputsMerged.toDouble / totalOutputs
val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2)
val result = new JHashMap[T, BoundedDouble](sums.size)
- val iter = sums.object2LongEntrySet.fastIterator()
- while (iter.hasNext) {
- val entry = iter.next()
- val sum = entry.getLongValue
+ sums.foreach { case (key, sum) =>
val mean = (sum + 1 - p) / p
val variance = (sum + 1) * (1 - p) / (p * p)
val stdev = math.sqrt(variance)
val low = mean - confFactor * stdev
val high = mean + confFactor * stdev
- result(entry.getKey) = new BoundedDouble(mean, confidence, low, high)
+ result(key) = new BoundedDouble(mean, confidence, low, high)
}
result
}