path: root/core/src
author    Matei Zaharia <matei@eecs.berkeley.edu>  2013-06-23 10:07:16 -0700
committer Matei Zaharia <matei@eecs.berkeley.edu>  2013-06-23 10:26:53 -0700
commit    78ffe164b33c6b11a2e511442605acd2f795a1b5 (patch)
tree      e10925c8e1d675c0a4ee95bd30414dd86ad13922 /core/src
parent    0e0f9d3069039f03bbf5eefe3b0637c89fddf0f1 (diff)
Clone the zero value for each key in foldByKey
The old version reused the same object within each task, so a zero value of a mutable type (expected to be common in fold) was overwritten as successive keys were folded.

Conflicts:
	core/src/test/scala/spark/ShuffleSuite.scala
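The failure mode is easy to reproduce outside Spark. Below is a minimal sketch in plain Scala (no Spark APIs; all names are illustrative) of what goes wrong when a single mutable zero object is shared across keys:

import scala.collection.mutable.ArrayBuffer

object SharedZeroDemo {
  def main(args: Array[String]): Unit = {
    // One shared zero value, as in the old foldByKey implementation
    val zero = new ArrayBuffer[Int]
    val groups = Map(1 -> Seq(1, 2), 2 -> Seq(3))

    // Every key folds into the *same* object, so results bleed together:
    // both entries end up referencing ArrayBuffer(1, 2, 3)
    val broken = groups.map { case (k, vs) => k -> vs.foldLeft(zero)(_ += _) }
    println(broken)
  }
}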
Diffstat (limited to 'core/src')
 core/src/main/scala/spark/PairRDDFunctions.scala | 15 ++++++++++++---
 core/src/test/scala/spark/ShuffleSuite.scala     | 22 ++++++++++++++++++++++
 2 files changed, 34 insertions(+), 3 deletions(-)
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index fa4bbfc76f..7630fe7803 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -1,5 +1,6 @@
package spark
+import java.nio.ByteBuffer
import java.util.{Date, HashMap => JHashMap}
import java.text.SimpleDateFormat
@@ -64,8 +65,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
throw new SparkException("Default partitioner cannot partition array keys.")
}
}
- val aggregator =
- new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
+ val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
if (self.partitioner == Some(partitioner)) {
self.mapPartitions(aggregator.combineValuesByKey(_), true)
} else if (mapSideCombine) {
@@ -97,7 +97,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* list concatenation, 0 for addition, or 1 for multiplication.).
*/
def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = {
- combineByKey[V]({v: V => func(zeroValue, v)}, func, func, partitioner)
+ // Serialize the zero value to a byte array so that we can get a new clone of it on each key
+ val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue)
+ val zeroArray = new Array[Byte](zeroBuffer.limit)
+ zeroBuffer.get(zeroArray)
+
+ // When deserializing, use a lazy val to create just one instance of the serializer per task
+ lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance()
+ def createZero() = cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray))
+
+ combineByKey[V]((v: V) => func(createZero(), v), func, func, partitioner)
}
/**
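The fix above serializes the zero value to bytes once, then deserializes a fresh copy for each key before folding into it. The same round-trip pattern can be sketched standalone; this sketch substitutes plain java.io serialization for Spark's closureSerializer, so it is illustrative rather than the actual Spark code path:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import scala.collection.mutable.ArrayBuffer

object CloneBySerialization {
  // Freeze a value into a byte array once, up front
  def serialize(value: AnyRef): Array[Byte] = {
    val buf = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(buf)
    out.writeObject(value)
    out.close()
    buf.toByteArray
  }

  // Each call materializes an independent copy from the bytes
  def deserialize[T](bytes: Array[Byte]): T =
    new ObjectInputStream(new ByteArrayInputStream(bytes)).readObject().asInstanceOf[T]

  def main(args: Array[String]): Unit = {
    val frozen = serialize(ArrayBuffer(0))
    val a = deserialize[ArrayBuffer[Int]](frozen)
    val b = deserialize[ArrayBuffer[Int]](frozen)
    a += 1
    println(a) // ArrayBuffer(0, 1)
    println(b) // ArrayBuffer(0) -- unaffected by mutations of a
  }
}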
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 1916885a73..0c1ec29f96 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -392,6 +392,28 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
assert(nonEmptyBlocks.size <= 4)
}
+ test("foldByKey") {
+ sc = new SparkContext("local", "test")
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val sums = pairs.foldByKey(0)(_+_).collect()
+ assert(sums.toSet === Set((1, 7), (2, 1)))
+ }
+
+ test("foldByKey with mutable result type") {
+ sc = new SparkContext("local", "test")
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val bufs = pairs.mapValues(v => ArrayBuffer(v)).cache()
+ // Fold the values using in-place mutation
+ val sums = bufs.foldByKey(new ArrayBuffer[Int])(_ ++= _).collect()
+ assert(sums.toSet === Set((1, ArrayBuffer(1, 2, 3, 1)), (2, ArrayBuffer(1))))
+ // Check that the mutable objects in the original RDD were not changed
+ assert(bufs.collect().toSet === Set(
+ (1, ArrayBuffer(1)),
+ (1, ArrayBuffer(2)),
+ (1, ArrayBuffer(3)),
+ (1, ArrayBuffer(1)),
+ (2, ArrayBuffer(1))))
+ }
}
object ShuffleSuite {