diff options
author | Aaron Davidson <aaron@databricks.com> | 2014-10-21 13:15:29 -0700 |
---|---|---|
committer | Andrew Or <andrewor14@gmail.com> | 2014-10-21 13:15:29 -0700 |
commit | 5fdaf52a9df21cac69e2a4612aeb4e760e4424e7 (patch) | |
tree | 2ac64dfaf90aab9c925ffcd04e59fc059c0ed408 /core | |
parent | 1a623b2e163da3a9112cb9b68bda22b9e398ed5c (diff) | |
download | spark-5fdaf52a9df21cac69e2a4612aeb4e760e4424e7.tar.gz spark-5fdaf52a9df21cac69e2a4612aeb4e760e4424e7.tar.bz2 spark-5fdaf52a9df21cac69e2a4612aeb4e760e4424e7.zip |
[SPARK-3994] Use standard Aggregator code path for countByKey and countByValue
See [JIRA](https://issues.apache.org/jira/browse/SPARK-3994) for more information. Also adds
a note which warns against using these methods.
Author: Aaron Davidson <aaron@databricks.com>
Closes #2839 from aarondav/countByKey and squashes the following commits:
d6fdb2a [Aaron Davidson] Respond to comments
e1f06d3 [Aaron Davidson] [SPARK-3994] Use standard Aggregator code path for countByKey and countByValue
Diffstat (limited to 'core')
-rw-r--r-- | core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala | 11 | ||||
-rw-r--r-- | core/src/main/scala/org/apache/spark/rdd/RDD.scala | 31 |
2 files changed, 16 insertions, 26 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index ac96de86dd..da89f634ab 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -315,8 +315,15 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) @deprecated("Use reduceByKeyLocally", "1.0.0") def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = reduceByKeyLocally(func) - /** Count the number of elements for each key, and return the result to the master as a Map. */ - def countByKey(): Map[K, Long] = self.map(_._1).countByValue() + /** + * Count the number of elements for each key, collecting the results to a local Map. + * + * Note that this method should only be used if the resulting map is expected to be small, as + * the whole thing is loaded into the driver's memory. + * To handle very large results, consider using rdd.mapValues(_ => 1L).reduceByKey(_ + _), which + * returns an RDD[T, Long] instead of a map. + */ + def countByKey(): Map[K, Long] = self.mapValues(_ => 1L).reduceByKey(_ + _).collect().toMap /** * :: Experimental :: diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 71cabf61d4..b7f125d01d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -927,32 +927,15 @@ abstract class RDD[T: ClassTag]( } /** - * Return the count of each unique value in this RDD as a map of (value, count) pairs. The final - * combine step happens locally on the master, equivalent to running a single reduce task. + * Return the count of each unique value in this RDD as a local map of (value, count) pairs. + * + * Note that this method should only be used if the resulting map is expected to be small, as + * the whole thing is loaded into the driver's memory. + * To handle very large results, consider using rdd.map(x => (x, 1L)).reduceByKey(_ + _), which + * returns an RDD[T, Long] instead of a map. */ def countByValue()(implicit ord: Ordering[T] = null): Map[T, Long] = { - if (elementClassTag.runtimeClass.isArray) { - throw new SparkException("countByValue() does not support arrays") - } - // TODO: This should perhaps be distributed by default. - val countPartition = (iter: Iterator[T]) => { - val map = new OpenHashMap[T,Long] - iter.foreach { - t => map.changeValue(t, 1L, _ + 1L) - } - Iterator(map) - }: Iterator[OpenHashMap[T,Long]] - val mergeMaps = (m1: OpenHashMap[T,Long], m2: OpenHashMap[T,Long]) => { - m2.foreach { case (key, value) => - m1.changeValue(key, value, _ + value) - } - m1 - }: OpenHashMap[T,Long] - val myResult = mapPartitions(countPartition).reduce(mergeMaps) - // Convert to a Scala mutable map - val mutableResult = scala.collection.mutable.Map[T,Long]() - myResult.foreach { case (k, v) => mutableResult.put(k, v) } - mutableResult + map(value => (value, null)).countByKey() } /** |