Diffstat (limited to 'core/src/main/scala/spark/rdd/SubtractedRDD.scala')
-rw-r--r--  core/src/main/scala/spark/rdd/SubtractedRDD.scala | 18
1 file changed, 13 insertions(+), 5 deletions(-)
diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
index 5e56900b18..9274839bca 100644
--- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
@@ -15,6 +15,7 @@ import spark.SparkEnv
import spark.ShuffleDependency
import spark.OneToOneDependency
+
/**
* An optimized version of cogroup for set difference/subtraction.
*
@@ -34,7 +35,9 @@ import spark.OneToOneDependency
private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
@transient var rdd1: RDD[(K, V)],
@transient var rdd2: RDD[(K, W)],
- part: Partitioner) extends RDD[(K, V)](rdd1.context, Nil) {
+ part: Partitioner,
+ val serializerClass: String = null)
+ extends RDD[(K, V)](rdd1.context, Nil) {
override def getDependencies: Seq[Dependency[_]] = {
Seq(rdd1, rdd2).map { rdd =>
@@ -43,7 +46,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
new OneToOneDependency(rdd)
} else {
logInfo("Adding shuffle dependency with " + rdd)
- new ShuffleDependency(rdd.asInstanceOf[RDD[(K, Any)]], part)
+ new ShuffleDependency(rdd.asInstanceOf[RDD[(K, Any)]], part, serializerClass)
}
}
}
@@ -68,6 +71,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
val partition = p.asInstanceOf[CoGroupPartition]
+ val serializer = SparkEnv.get.serializerManager.get(serializerClass)
val map = new JHashMap[K, ArrayBuffer[V]]
def getSeq(k: K): ArrayBuffer[V] = {
val seq = map.get(k)
@@ -80,12 +84,16 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
}
}
def integrate(dep: CoGroupSplitDep, op: ((K, V)) => Unit) = dep match {
- case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
+ case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
for (t <- rdd.iterator(itsSplit, context))
op(t.asInstanceOf[(K, V)])
- case ShuffleCoGroupSplitDep(shuffleId) =>
- for (t <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index, context.taskMetrics))
+ }
+ case ShuffleCoGroupSplitDep(shuffleId) => {
+ val iter = SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index,
+ context.taskMetrics, serializer)
+ for (t <- iter)
op(t.asInstanceOf[(K, V)])
+ }
}
// the first dep is rdd1; add all values to the map
integrate(partition.deps(0), t => getSeq(t._1) += t._2)
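
For readers skimming the hunks above: compute() performs the subtraction by bucketing rdd1's pairs in a hash map and then deleting every key that rdd2 contributes (the second integrate() pass falls outside this excerpt, but the set-difference semantics in the doc comment imply it). Below is a minimal, self-contained sketch of that strategy using plain Scala collections in place of the dep/iterator machinery; SubtractSketch and its names are illustrative, not part of the patch.

    import scala.collection.mutable.{ArrayBuffer, HashMap}

    object SubtractSketch {
      // Mirrors the two integrate() passes over the JHashMap in compute().
      def subtract[K, V](left: Seq[(K, V)], right: Seq[(K, Any)]): Seq[(K, V)] = {
        val map = new HashMap[K, ArrayBuffer[V]]
        // First pass (partition.deps(0), i.e. rdd1): bucket every value by key.
        for ((k, v) <- left)
          map.getOrElseUpdate(k, new ArrayBuffer[V]) += v
        // Second pass (partition.deps(1), i.e. rdd2): drop any key the right side has.
        for ((k, _) <- right)
          map.remove(k)
        // The survivors are exactly left minus right, by key.
        map.toSeq.flatMap { case (k, vs) => vs.map(v => (k, v)) }
      }

      def main(args: Array[String]) {
        // Prints (a,1) and (c,3); the key "b" is subtracted away.
        println(subtract(Seq(("a", 1), ("b", 2), ("c", 3)), Seq(("b", 99))))
      }
    }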
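
And a hypothetical call site exercising the new constructor parameter. SubtractedRDD is private[spark], so the sketch sits inside the spark package, and "spark.KryoSerializer" is an assumed serializer class name from this era of the codebase; passing null (the parameter's default) would keep the system default serializer.

    package spark

    import spark.rdd.SubtractedRDD

    object SubtractedRDDExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "SubtractedRDDExample")
        val rdd1 = sc.parallelize(Seq(("a", 1), ("b", 2), ("c", 3)))
        val rdd2 = sc.parallelize(Seq(("b", 99)))
        // The new fourth argument names the serializer class used for the
        // shuffle; "spark.KryoSerializer" is an assumption, not from the patch.
        val result = new SubtractedRDD(rdd1, rdd2, new HashPartitioner(2), "spark.KryoSerializer")
        result.collect().foreach(println) // prints (a,1) and (c,3)
        sc.stop()
      }
    }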