diff options
author | Matei Zaharia <matei@eecs.berkeley.edu> | 2013-06-23 10:07:16 -0700 |
---|---|---|
committer | Matei Zaharia <matei@eecs.berkeley.edu> | 2013-06-23 10:26:53 -0700 |
commit | 78ffe164b33c6b11a2e511442605acd2f795a1b5 (patch) | |
tree | e10925c8e1d675c0a4ee95bd30414dd86ad13922 /core | |
parent | 0e0f9d3069039f03bbf5eefe3b0637c89fddf0f1 (diff) | |
download | spark-78ffe164b33c6b11a2e511442605acd2f795a1b5.tar.gz spark-78ffe164b33c6b11a2e511442605acd2f795a1b5.tar.bz2 spark-78ffe164b33c6b11a2e511442605acd2f795a1b5.zip |
Clone the zero value for each key in foldByKey
The old version reused the object within each task, leading to
overwriting of the object when a mutable type is used, which is expected
to be common in fold.
Conflicts:
core/src/test/scala/spark/ShuffleSuite.scala
Diffstat (limited to 'core')
-rw-r--r-- | core/src/main/scala/spark/PairRDDFunctions.scala | 15 | ||||
-rw-r--r-- | core/src/test/scala/spark/ShuffleSuite.scala | 22 |
2 files changed, 34 insertions, 3 deletions
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index fa4bbfc76f..7630fe7803 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -1,5 +1,6 @@ package spark +import java.nio.ByteBuffer import java.util.{Date, HashMap => JHashMap} import java.text.SimpleDateFormat @@ -64,8 +65,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( throw new SparkException("Default partitioner cannot partition array keys.") } } - val aggregator = - new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) + val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) if (self.partitioner == Some(partitioner)) { self.mapPartitions(aggregator.combineValuesByKey(_), true) } else if (mapSideCombine) { @@ -97,7 +97,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( * list concatenation, 0 for addition, or 1 for multiplication.). */ def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = { - combineByKey[V]({v: V => func(zeroValue, v)}, func, func, partitioner) + // Serialize the zero value to a byte array so that we can get a new clone of it on each key + val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue) + val zeroArray = new Array[Byte](zeroBuffer.limit) + zeroBuffer.get(zeroArray) + + // When deserializing, use a lazy val to create just one instance of the serializer per task + lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance() + def createZero() = cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray)) + + combineByKey[V]((v: V) => func(createZero(), v), func, func, partitioner) } /** diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index 1916885a73..0c1ec29f96 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -392,6 +392,28 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { assert(nonEmptyBlocks.size <= 4) } + test("foldByKey") { + sc = new SparkContext("local", "test") + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val sums = pairs.foldByKey(0)(_+_).collect() + assert(sums.toSet === Set((1, 7), (2, 1))) + } + + test("foldByKey with mutable result type") { + sc = new SparkContext("local", "test") + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val bufs = pairs.mapValues(v => ArrayBuffer(v)).cache() + // Fold the values using in-place mutation + val sums = bufs.foldByKey(new ArrayBuffer[Int])(_ ++= _).collect() + assert(sums.toSet === Set((1, ArrayBuffer(1, 2, 3, 1)), (2, ArrayBuffer(1)))) + // Check that the mutable objects in the original RDD were not changed + assert(bufs.collect().toSet === Set( + (1, ArrayBuffer(1)), + (1, ArrayBuffer(2)), + (1, ArrayBuffer(3)), + (1, ArrayBuffer(1)), + (2, ArrayBuffer(1)))) + } } object ShuffleSuite { |