author     Reynold Xin <reynoldx@gmail.com>    2013-08-19 00:40:43 -0700
committer  Reynold Xin <reynoldx@gmail.com>    2013-08-19 00:40:43 -0700
commit     71d705a66eb8782e5cd5c77853fdd99fd8155334 (patch)
tree       16c14b022fec9ddce284ba157b99512400e296b9 /core
parent     2a7b99c08b29d3002183a8d7ed3acd14fbf5dc41 (diff)
Made PairRDDFunctions take only Tuple2, but made the rest of the shuffle code path work with general Product2.
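In other words, the implicit PairRDDFunctions helpers now require plain Tuple2 records, while ShuffledRDD gains a third type parameter P <: Product2[K, V] so the shuffle machinery can move any Product2, including the MutablePair added in this patch. A minimal sketch of the resulting API, assuming an existing SparkContext named sc; it mirrors the test added to ShuffleSuite below rather than introducing anything beyond this patch:

import spark.{HashPartitioner, RDD, SparkContext}
import spark.SparkContext._   // rddToPairRDDFunctions now requires RDD[(K, V)]
import spark.rdd.ShuffledRDD
import spark.util.MutablePair

// PairRDDFunctions is only available on RDD[(K, V)] ...
val tuples: RDD[(Int, String)] = sc.parallelize(Seq((1, "a"), (1, "b"), (2, "c")))
val grouped = tuples.groupByKey()

// ... but ShuffledRDD[K, V, P] can shuffle any Product2[K, V], e.g. MutablePair.
val pairs: RDD[MutablePair[Int, Int]] =
  sc.parallelize(Seq(MutablePair(1, 1), MutablePair(1, 2), MutablePair(2, 1)))
val shuffled: RDD[MutablePair[Int, Int]] =
  new ShuffledRDD[Int, Int, MutablePair[Int, Int]](pairs, new HashPartitioner(2))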
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/spark/Aggregator.scala                 10
-rw-r--r--  core/src/main/scala/spark/BlockStoreShuffleFetcher.scala   11
-rw-r--r--  core/src/main/scala/spark/PairRDDFunctions.scala           50
-rw-r--r--  core/src/main/scala/spark/RDD.scala                         9
-rw-r--r--  core/src/main/scala/spark/ShuffleFetcher.scala              5
-rw-r--r--  core/src/main/scala/spark/SparkContext.scala                2
-rw-r--r--  core/src/main/scala/spark/api/java/JavaPairRDD.scala        2
-rw-r--r--  core/src/main/scala/spark/rdd/CoGroupedRDD.scala           12
-rw-r--r--  core/src/main/scala/spark/rdd/FlatMappedValuesRDD.scala     2
-rw-r--r--  core/src/main/scala/spark/rdd/MappedValuesRDD.scala         2
-rw-r--r--  core/src/main/scala/spark/rdd/OrderedRDDFunctions.scala    14
-rw-r--r--  core/src/main/scala/spark/rdd/ShuffledRDD.scala            12
-rw-r--r--  core/src/main/scala/spark/rdd/SubtractedRDD.scala          14
-rw-r--r--  core/src/main/scala/spark/scheduler/ShuffleMapTask.scala    2
-rw-r--r--  core/src/main/scala/spark/util/MutablePair.scala           34
-rw-r--r--  core/src/test/scala/spark/CheckpointSuite.scala             2
-rw-r--r--  core/src/test/scala/spark/PairRDDFunctionsSuite.scala       7
-rw-r--r--  core/src/test/scala/spark/RDDSuite.scala                    2
-rw-r--r--  core/src/test/scala/spark/ShuffleSuite.scala               31
19 files changed, 132 insertions, 91 deletions
diff --git a/core/src/main/scala/spark/Aggregator.scala b/core/src/main/scala/spark/Aggregator.scala
index 3920f8511c..9af401986d 100644
--- a/core/src/main/scala/spark/Aggregator.scala
+++ b/core/src/main/scala/spark/Aggregator.scala
@@ -34,12 +34,12 @@ case class Aggregator[K, V, C] (
def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]) : Iterator[(K, C)] = {
val combiners = new JHashMap[K, C]
- for ((k, v) <- iter) {
- val oldC = combiners.get(k)
+ for (kv <- iter) {
+ val oldC = combiners.get(kv._1)
if (oldC == null) {
- combiners.put(k, createCombiner(v))
+ combiners.put(kv._1, createCombiner(kv._2))
} else {
- combiners.put(k, mergeValue(oldC, v))
+ combiners.put(kv._1, mergeValue(oldC, kv._2))
}
}
combiners.iterator
@@ -47,7 +47,7 @@ case class Aggregator[K, V, C] (
def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = {
val combiners = new JHashMap[K, C]
- for ((k, c) <- iter) {
+ iter.foreach { case(k, c) =>
val oldC = combiners.get(k)
if (oldC == null) {
combiners.put(k, c)
diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
index 8f6953b1f5..1ec95ed9b8 100644
--- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
@@ -28,8 +28,9 @@ import spark.util.CompletionIterator
private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
- override def fetch[K, V](
- shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer) = {
+ override def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer)
+ : Iterator[T] =
+ {
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
val blockManager = SparkEnv.get.blockManager
@@ -49,12 +50,12 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
(address, splits.map(s => ("shuffle_%d_%d_%d".format(shuffleId, s._1, reduceId), s._2)))
}
- def unpackBlock(blockPair: (String, Option[Iterator[Any]])) : Iterator[(K, V)] = {
+ def unpackBlock(blockPair: (String, Option[Iterator[Any]])) : Iterator[T] = {
val blockId = blockPair._1
val blockOption = blockPair._2
blockOption match {
case Some(block) => {
- block.asInstanceOf[Iterator[(K, V)]]
+ block.asInstanceOf[Iterator[T]]
}
case None => {
val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r
@@ -73,7 +74,7 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
val itr = blockFetcherItr.flatMap(unpackBlock)
- CompletionIterator[(K,V), Iterator[(K,V)]](itr, {
+ CompletionIterator[T, Iterator[T]](itr, {
val shuffleMetrics = new ShuffleReadMetrics
shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index 3ae703ce1a..f8900d3921 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -48,7 +48,7 @@ import spark.Partitioner._
* Extra functions available on RDDs of (key, value) pairs through an implicit conversion.
* Import `spark.SparkContext._` at the top of your program to use these functions.
*/
-class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[_ <: Product2[K, V]])
+class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
extends Logging
with HadoopMapReduceUtil
with Serializable {
@@ -85,13 +85,14 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[_ <: Produc
self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
} else if (mapSideCombine) {
val combined = self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
- val partitioned = new ShuffledRDD[K, C](combined, partitioner).setSerializer(serializerClass)
+ val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner)
+ .setSerializer(serializerClass)
partitioned.mapPartitions(aggregator.combineCombinersByKey, preservesPartitioning = true)
} else {
// Don't apply map-side combiner.
// A sanity check to make sure mergeCombiners is not defined.
assert(mergeCombiners == null)
- val values = new ShuffledRDD[K, V](self, partitioner).setSerializer(serializerClass)
+ val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass)
values.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
}
}
@@ -162,7 +163,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[_ <: Produc
throw new SparkException("reduceByKeyLocally() does not support array keys")
}
- def reducePartition(iter: Iterator[Product2[K, V]]): Iterator[JHashMap[K, V]] = {
+ def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = {
val map = new JHashMap[K, V]
for ((k, v) <- iter) {
val old = map.get(k)
@@ -236,7 +237,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[_ <: Produc
if (getKeyClass().isArray && partitioner.isInstanceOf[HashPartitioner]) {
throw new SparkException("Default partitioner cannot partition array keys.")
}
- new ShuffledRDD[K, V](self, partitioner)
+ new ShuffledRDD[K, V, (K, V)](self, partitioner)
}
/**
@@ -245,9 +246,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[_ <: Produc
* (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD.
*/
def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = {
- this.cogroup(other, partitioner).flatMapValues {
- case (vs, ws) =>
- for (v <- vs.iterator; w <- ws.iterator) yield (v, w)
+ this.cogroup(other, partitioner).flatMapValues { case (vs, ws) =>
+ for (v <- vs.iterator; w <- ws.iterator) yield (v, w)
}
}
@@ -258,13 +258,12 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[_ <: Produc
* partition the output RDD.
*/
def leftOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, Option[W]))] = {
- this.cogroup(other, partitioner).flatMapValues {
- case (vs, ws) =>
- if (ws.isEmpty) {
- vs.iterator.map(v => (v, None))
- } else {
- for (v <- vs.iterator; w <- ws.iterator) yield (v, Some(w))
- }
+ this.cogroup(other, partitioner).flatMapValues { case (vs, ws) =>
+ if (ws.isEmpty) {
+ vs.iterator.map(v => (v, None))
+ } else {
+ for (v <- vs.iterator; w <- ws.iterator) yield (v, Some(w))
+ }
}
}
@@ -276,13 +275,12 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[_ <: Produc
*/
def rightOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner)
: RDD[(K, (Option[V], W))] = {
- this.cogroup(other, partitioner).flatMapValues {
- case (vs, ws) =>
- if (vs.isEmpty) {
- ws.iterator.map(w => (None, w))
- } else {
- for (v <- vs.iterator; w <- ws.iterator) yield (Some(v), w)
- }
+ this.cogroup(other, partitioner).flatMapValues { case (vs, ws) =>
+ if (vs.isEmpty) {
+ ws.iterator.map(w => (None, w))
+ } else {
+ for (v <- vs.iterator; w <- ws.iterator) yield (Some(v), w)
+ }
}
}
@@ -378,7 +376,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[_ <: Produc
val data = self.toArray()
val map = new mutable.HashMap[K, V]
map.sizeHint(data.length)
- data.foreach { case(k, v) => map.put(k, v) }
+ data.foreach { case (k, v) => map.put(k, v) }
map
}
@@ -501,7 +499,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[_ <: Produc
self.partitioner match {
case Some(p) =>
val index = p.getPartition(key)
- def process(it: Iterator[Product2[K, V]]): Seq[V] = {
+ def process(it: Iterator[(K, V)]): Seq[V] = {
val buf = new ArrayBuffer[V]
for ((k, v) <- it if k == key) {
buf += v
@@ -559,7 +557,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[_ <: Produc
val formatter = new SimpleDateFormat("yyyyMMddHHmm")
val jobtrackerID = formatter.format(new Date())
val stageId = self.id
- def writeShard(context: spark.TaskContext, iter: Iterator[Product2[K,V]]): Int = {
+ def writeShard(context: spark.TaskContext, iter: Iterator[(K,V)]): Int = {
// Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
@@ -658,7 +656,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[_ <: Produc
val writer = new HadoopWriter(conf)
writer.preSetup()
- def writeToFile(context: TaskContext, iter: Iterator[Product2[K,V]]) {
+ def writeToFile(context: TaskContext, iter: Iterator[(K, V)]) {
// Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 04b37df212..c9a044afab 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -31,9 +31,8 @@ import org.apache.hadoop.mapred.TextOutputFormat
import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
-import spark.api.java.JavaRDD
-import spark.broadcast.Broadcast
import spark.Partitioner._
+import spark.api.java.JavaRDD
import spark.partial.BoundedDouble
import spark.partial.CountEvaluator
import spark.partial.GroupedCountEvaluator
@@ -288,7 +287,7 @@ abstract class RDD[T: ClassManifest](
if (shuffle) {
// include a shuffle step so that our upstream tasks are still distributed
new CoalescedRDD(
- new ShuffledRDD(map(x => (x, null)),
+ new ShuffledRDD[T, Null, (T, Null)](map(x => (x, null)),
new HashPartitioner(numPartitions)),
numPartitions).keys
} else {
@@ -305,8 +304,8 @@ abstract class RDD[T: ClassManifest](
def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = {
var fraction = 0.0
var total = 0
- var multiplier = 3.0
- var initialCount = this.count()
+ val multiplier = 3.0
+ val initialCount = this.count()
var maxSelected = 0
if (num < 0) {
diff --git a/core/src/main/scala/spark/ShuffleFetcher.scala b/core/src/main/scala/spark/ShuffleFetcher.scala
index dcced035e7..a6839cf7a4 100644
--- a/core/src/main/scala/spark/ShuffleFetcher.scala
+++ b/core/src/main/scala/spark/ShuffleFetcher.scala
@@ -22,12 +22,13 @@ import spark.serializer.Serializer
private[spark] abstract class ShuffleFetcher {
+
/**
* Fetch the shuffle outputs for a given ShuffleDependency.
* @return An iterator over the elements of the fetched shuffle outputs.
*/
- def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics,
- serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[(K,V)]
+ def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics,
+ serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[T]
/** Stop the fetcher */
def stop() {}
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 5db1767146..185c76366f 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -880,7 +880,7 @@ object SparkContext {
implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
rdd: RDD[(K, V)]) =
- new OrderedRDDFunctions(rdd)
+ new OrderedRDDFunctions[K, V, (K, V)](rdd)
implicit def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = new DoubleRDDFunctions(rdd)
diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
index f5632428e7..effe6e5e0d 100644
--- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
@@ -564,7 +564,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
override def compare(b: K) = comp.compare(a, b)
}
implicit def toOrdered(x: K): Ordered[K] = new KeyOrdering(x)
- fromRDD(new OrderedRDDFunctions(rdd).sortByKey(ascending))
+ fromRDD(new OrderedRDDFunctions[K, V, (K, V)](rdd).sortByKey(ascending))
}
/**
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index 06e15bb73c..01b6c23dcc 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -73,10 +73,10 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
override def getDependencies: Seq[Dependency[_]] = {
rdds.map { rdd: RDD[_ <: Product2[K, _]] =>
if (rdd.partitioner == Some(part)) {
- logInfo("Adding one-to-one dependency with " + rdd)
+ logDebug("Adding one-to-one dependency with " + rdd)
new OneToOneDependency(rdd)
} else {
- logInfo("Adding shuffle dependency with " + rdd)
+ logDebug("Adding shuffle dependency with " + rdd)
new ShuffleDependency[Any, Any](rdd, part, serializerClass)
}
}
@@ -122,15 +122,15 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
// Read them from the parent
- for ((k, v) <- rdd.iterator(itsSplit, context)) {
- getSeq(k.asInstanceOf[K])(depNum) += v
+ rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]].foreach { kv =>
+ getSeq(kv._1)(depNum) += kv._2
}
}
case ShuffleCoGroupSplitDep(shuffleId) => {
// Read map outputs of shuffle
val fetcher = SparkEnv.get.shuffleFetcher
- fetcher.fetch[K, Any](shuffleId, split.index, context.taskMetrics, ser).foreach {
- case (key, value) => getSeq(key)(depNum) += value
+ fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context.taskMetrics, ser).foreach {
+ kv => getSeq(kv._1)(depNum) += kv._2
}
}
}
diff --git a/core/src/main/scala/spark/rdd/FlatMappedValuesRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedValuesRDD.scala
index 05fdfd82c1..a6bdce89d8 100644
--- a/core/src/main/scala/spark/rdd/FlatMappedValuesRDD.scala
+++ b/core/src/main/scala/spark/rdd/FlatMappedValuesRDD.scala
@@ -29,7 +29,7 @@ class FlatMappedValuesRDD[K, V, U](prev: RDD[_ <: Product2[K, V]], f: V => Trave
override val partitioner = firstParent[Product2[K, V]].partitioner
override def compute(split: Partition, context: TaskContext) = {
- firstParent[Product2[K, V]].iterator(split, context).flatMap { case (k, v) =>
+ firstParent[Product2[K, V]].iterator(split, context).flatMap { case Product2(k, v) =>
f(v).map(x => (k, x))
}
}
diff --git a/core/src/main/scala/spark/rdd/MappedValuesRDD.scala b/core/src/main/scala/spark/rdd/MappedValuesRDD.scala
index 21ae97daa9..8334e3b557 100644
--- a/core/src/main/scala/spark/rdd/MappedValuesRDD.scala
+++ b/core/src/main/scala/spark/rdd/MappedValuesRDD.scala
@@ -29,6 +29,6 @@ class MappedValuesRDD[K, V, U](prev: RDD[_ <: Product2[K, V]], f: V => U)
override val partitioner = firstParent[Product2[K, U]].partitioner
override def compute(split: Partition, context: TaskContext): Iterator[(K, U)] = {
- firstParent[Product2[K, V]].iterator(split, context).map { case(k ,v) => (k, f(v)) }
+ firstParent[Product2[K, V]].iterator(split, context).map { case Product2(k ,v) => (k, f(v)) }
}
}
diff --git a/core/src/main/scala/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/spark/rdd/OrderedRDDFunctions.scala
index 6328c6a4ac..9154b76035 100644
--- a/core/src/main/scala/spark/rdd/OrderedRDDFunctions.scala
+++ b/core/src/main/scala/spark/rdd/OrderedRDDFunctions.scala
@@ -24,8 +24,10 @@ import spark.{RangePartitioner, Logging, RDD}
* an implicit conversion. Import `spark.SparkContext._` at the top of your program to use these
* functions. They will work with any key type that has a `scala.math.Ordered` implementation.
*/
-class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
- self: RDD[_ <: Product2[K, V]])
+class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest,
+ V: ClassManifest,
+ P <: Product2[K, V] : ClassManifest](
+ self: RDD[P])
extends Logging with Serializable {
/**
@@ -34,11 +36,9 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
* (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
* order of the keys).
*/
- def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size)
- : RDD[(K, V)] =
- {
- val part = new RangePartitioner(numPartitions, self.asInstanceOf[RDD[Product2[K,V]]], ascending)
- val shuffled = new ShuffledRDD[K, V](self, part)
+ def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = {
+ val part = new RangePartitioner(numPartitions, self, ascending)
+ val shuffled = new ShuffledRDD[K, V, P](self, part)
shuffled.mapPartitions(iter => {
val buf = iter.toArray
if (ascending) {
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index 2eac62f9c0..51c05af064 100644
--- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -32,14 +32,14 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
* @tparam K the key class.
* @tparam V the value class.
*/
-class ShuffledRDD[K, V](
- @transient var prev: RDD[_ <: Product2[K, V]],
+class ShuffledRDD[K, V, P <: Product2[K, V] : ClassManifest](
+ @transient var prev: RDD[P],
part: Partitioner)
- extends RDD[(K, V)](prev.context, Nil) {
+ extends RDD[P](prev.context, Nil) {
private var serializerClass: String = null
- def setSerializer(cls: String): ShuffledRDD[K, V] = {
+ def setSerializer(cls: String): ShuffledRDD[K, V, P] = {
serializerClass = cls
this
}
@@ -54,9 +54,9 @@ class ShuffledRDD[K, V](
Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRDDPartition(i))
}
- override def compute(split: Partition, context: TaskContext): Iterator[(K, V)] = {
+ override def compute(split: Partition, context: TaskContext): Iterator[P] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
- SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index, context.taskMetrics,
+ SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context.taskMetrics,
SparkEnv.get.serializerManager.get(serializerClass))
}
diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
index 200e85d432..dadef5e17d 100644
--- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
@@ -62,10 +62,10 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM
override def getDependencies: Seq[Dependency[_]] = {
Seq(rdd1, rdd2).map { rdd =>
if (rdd.partitioner == Some(part)) {
- logInfo("Adding one-to-one dependency with " + rdd)
+ logDebug("Adding one-to-one dependency with " + rdd)
new OneToOneDependency(rdd)
} else {
- logInfo("Adding shuffle dependency with " + rdd)
+ logDebug("Adding shuffle dependency with " + rdd)
new ShuffleDependency(rdd, part, serializerClass)
}
}
@@ -103,16 +103,14 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM
seq
}
}
- def integrate(dep: CoGroupSplitDep, op: ((K, V)) => Unit) = dep match {
+ def integrate(dep: CoGroupSplitDep, op: Product2[K, V] => Unit) = dep match {
case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
- for (t <- rdd.iterator(itsSplit, context))
- op(t.asInstanceOf[(K, V)])
+ rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op)
}
case ShuffleCoGroupSplitDep(shuffleId) => {
- val iter = SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index,
+ val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index,
context.taskMetrics, serializer)
- for (t <- iter)
- op(t.asInstanceOf[(K, V)])
+ iter.foreach(op)
}
}
// the first dep is rdd1; add all values to the map
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index e3bb6d1e60..121ff31121 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -148,7 +148,7 @@ private[spark] class ShuffleMapTask(
// Write the map output to its associated buckets.
for (elem <- rdd.iterator(split, taskContext)) {
- val pair = elem.asInstanceOf[(Any, Any)]
+ val pair = elem.asInstanceOf[Product2[Any, Any]]
val bucketId = dep.partitioner.getPartition(pair._1)
buckets.writers(bucketId).write(pair)
}
diff --git a/core/src/main/scala/spark/util/MutablePair.scala b/core/src/main/scala/spark/util/MutablePair.scala
new file mode 100644
index 0000000000..117218bf47
--- /dev/null
+++ b/core/src/main/scala/spark/util/MutablePair.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.util
+
+
+/** A tuple of 2 elements.
+ * @param _1 Element 1 of this MutablePair
+ * @param _2 Element 2 of this MutablePair
+ */
+case class MutablePair[@specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T1,
+ @specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T2]
+ (var _1: T1,var _2: T2)
+ extends Product2[T1, T2]
+{
+
+ override def toString = "(" + _1 + "," + _2 + ")"
+
+ def canEqual(that: Any): Boolean = that.isInstanceOf[MutablePair[T1, T2]]
+}
diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala
index a84c89e3c9..966dede2be 100644
--- a/core/src/test/scala/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/spark/CheckpointSuite.scala
@@ -99,7 +99,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
test("ShuffledRDD") {
testCheckpointing(rdd => {
// Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD
- new ShuffledRDD(rdd.map(x => (x % 2, 1)), partitioner)
+ new ShuffledRDD[Int, Int, (Int, Int)](rdd.map(x => (x % 2, 1)), partitioner)
})
}
diff --git a/core/src/test/scala/spark/PairRDDFunctionsSuite.scala b/core/src/test/scala/spark/PairRDDFunctionsSuite.scala
index b102eaf4e6..328b3b5497 100644
--- a/core/src/test/scala/spark/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/spark/PairRDDFunctionsSuite.scala
@@ -21,16 +21,11 @@ import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashSet
import org.scalatest.FunSuite
-import org.scalatest.prop.Checkers
-import org.scalacheck.Arbitrary._
-import org.scalacheck.Gen
-import org.scalacheck.Prop._
import com.google.common.io.Files
-
-import spark.rdd.ShuffledRDD
import spark.SparkContext._
+
class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
test("groupByKey") {
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index cbddf4e523..75778de1cc 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -170,7 +170,7 @@ class RDDSuite extends FunSuite with SharedSparkContext {
// we can optionally shuffle to keep the upstream parallel
val coalesced5 = data.coalesce(1, shuffle = true)
- assert(coalesced5.dependencies.head.rdd.dependencies.head.rdd.asInstanceOf[ShuffledRDD[_, _]] !=
+ assert(coalesced5.dependencies.head.rdd.dependencies.head.rdd.asInstanceOf[ShuffledRDD[_, _, _]] !=
null)
}
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index c686b8cc5a..c319a57fdd 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -22,6 +22,9 @@ import org.scalatest.matchers.ShouldMatchers
import spark.rdd.ShuffledRDD
import spark.SparkContext._
+import spark.ShuffleSuite.NonJavaSerializableClass
+import spark.util.MutablePair
+
class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
test("groupByKey without compression") {
@@ -46,12 +49,12 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
val a = sc.parallelize(1 to 10, 2)
val b = a.map { x =>
- (x, new ShuffleSuite.NonJavaSerializableClass(x * 2))
+ (x, new NonJavaSerializableClass(x * 2))
}
// If the Kryo serializer is not used correctly, the shuffle would fail because the
// default Java serializer cannot handle the non serializable class.
- val c = new ShuffledRDD(b, new HashPartitioner(NUM_BLOCKS))
- .setSerializer(classOf[spark.KryoSerializer].getName)
+ val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)](
+ b, new HashPartitioner(NUM_BLOCKS)).setSerializer(classOf[spark.KryoSerializer].getName)
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
assert(c.count === 10)
@@ -68,12 +71,12 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
sc = new SparkContext("local-cluster[2,1,512]", "test")
val a = sc.parallelize(1 to 10, 2)
val b = a.map { x =>
- (x, new ShuffleSuite.NonJavaSerializableClass(x * 2))
+ (x, new NonJavaSerializableClass(x * 2))
}
// If the Kryo serializer is not used correctly, the shuffle would fail because the
// default Java serializer cannot handle the non serializable class.
- val c = new ShuffledRDD(b, new HashPartitioner(3))
- .setSerializer(classOf[spark.KryoSerializer].getName)
+ val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)](
+ b, new HashPartitioner(3)).setSerializer(classOf[spark.KryoSerializer].getName)
assert(c.count === 10)
}
@@ -88,7 +91,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
// NOTE: The default Java serializer doesn't create zero-sized blocks.
// So, use Kryo
- val c = new ShuffledRDD(b, new HashPartitioner(10))
+ val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10))
.setSerializer(classOf[spark.KryoSerializer].getName)
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
@@ -114,7 +117,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
val b = a.map(x => (x, x*2))
// NOTE: The default Java serializer should create zero-sized blocks
- val c = new ShuffledRDD(b, new HashPartitioner(10))
+ val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10))
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
assert(c.count === 4)
@@ -128,6 +131,18 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
// We should have at most 4 non-zero sized partitions
assert(nonEmptyBlocks.size <= 4)
}
+
+ test("shuffle using mutable pairs") {
+ // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
+ def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2)
+ val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1))
+ val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data)
+ val results = new ShuffledRDD[Int, Int, MutablePair[Int, Int]](pairs, new HashPartitioner(2))
+ .collect()
+
+ data.foreach { pair => results should contain (pair) }
+ }
}
object ShuffleSuite {