author     Reynold Xin <reynoldx@gmail.com>   2013-08-17 21:43:29 -0700
committer  Reynold Xin <reynoldx@gmail.com>   2013-08-17 21:43:29 -0700
commit     2c00ea3efc7d9a23af8ba11352460294e1865942 (patch)
tree       a72c858cf44b872fd6d14a31e015d3e02c627727 /core
parent     0e84fee76b529089fb52f15151202e9a7b847ed5 (diff)
Moved shuffle serializer setting from a constructor parameter to a setSerializer method in various RDDs that involve shuffle operations.
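At call sites, the change replaces the serializer constructor argument with a
chained setter call. A minimal sketch of the before/after usage, drawn from the
test changes below (shuffleWithKryo is a hypothetical helper name, not part of
this commit):

    import spark.{HashPartitioner, KryoSerializer, RDD}
    import spark.rdd.ShuffledRDD

    // Hypothetical helper illustrating the call-site change in this commit.
    def shuffleWithKryo(b: RDD[(Int, Int)]): ShuffledRDD[Int, Int] = {
      // Before: the serializer class name was a constructor parameter:
      //   new ShuffledRDD(b, new HashPartitioner(3), classOf[KryoSerializer].getName)
      // After: setSerializer returns the RDD itself, so the call chains:
      new ShuffledRDD(b, new HashPartitioner(3))
        .setSerializer(classOf[KryoSerializer].getName)
    }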
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/spark/PairRDDFunctions.scala  | 13
-rw-r--r--  core/src/main/scala/spark/rdd/CoGroupedRDD.scala  | 12
-rw-r--r--  core/src/main/scala/spark/rdd/ShuffledRDD.scala   | 29
-rw-r--r--  core/src/main/scala/spark/rdd/SubtractedRDD.scala | 10
-rw-r--r--  core/src/test/scala/spark/ShuffleSuite.scala      | 19
5 files changed, 51 insertions(+), 32 deletions(-)
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index fa9df3a97e..0be4b4feb8 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -85,17 +85,18 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
if (self.partitioner == Some(partitioner)) {
- self.mapPartitions(aggregator.combineValuesByKey(_), true)
+ self.mapPartitions(aggregator.combineValuesByKey, true)
} else if (mapSideCombine) {
- val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true)
- val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner, serializerClass)
- partitioned.mapPartitions(aggregator.combineCombinersByKey(_), true)
+ val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey, true)
+ val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner)
+ .setSerializer(serializerClass)
+ partitioned.mapPartitions(aggregator.combineCombinersByKey, true)
} else {
// Don't apply map-side combiner.
// A sanity check to make sure mergeCombiners is not defined.
assert(mergeCombiners == null)
- val values = new ShuffledRDD[K, V](self, partitioner, serializerClass)
- values.mapPartitions(aggregator.combineValuesByKey(_), true)
+ val values = new ShuffledRDD[K, V](self, partitioner).setSerializer(serializerClass)
+ values.mapPartitions(aggregator.combineValuesByKey, true)
}
}
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index 019b12d2d5..c2d95dc060 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -60,12 +60,16 @@ class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
* @param rdds parent RDDs.
* @param part partitioner used to partition the shuffle output.
*/
-class CoGroupedRDD[K](
- @transient var rdds: Seq[RDD[(K, _)]],
- part: Partitioner,
- val serializerClass: String = null)
+class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) {
+ private var serializerClass: String = null
+
+ def setSerializer(cls: String): CoGroupedRDD[K] = {
+ serializerClass = cls
+ this
+ }
+
override def getDependencies: Seq[Dependency[_]] = {
rdds.map { rdd: RDD[(K, _)] =>
if (rdd.partitioner == Some(part)) {
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index 0137f80953..bcf7d0d89c 100644
--- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -17,8 +17,9 @@
package spark.rdd
-import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Partition, TaskContext}
-import spark.SparkContext._
+import spark._
+import scala.Some
+import scala.Some
private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
@@ -30,15 +31,24 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
* The resulting RDD from a shuffle (e.g. repartitioning of data).
* @param prev the parent RDD.
* @param part the partitioner used to partition the RDD
- * @param serializerClass class name of the serializer to use.
* @tparam K the key class.
* @tparam V the value class.
*/
class ShuffledRDD[K, V](
- @transient prev: RDD[(K, V)],
- part: Partitioner,
- serializerClass: String = null)
- extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part, serializerClass))) {
+ @transient var prev: RDD[(K, V)],
+ part: Partitioner)
+ extends RDD[(K, V)](prev.context, Nil) {
+
+ private var serializerClass: String = null
+
+ def setSerializer(cls: String): ShuffledRDD[K, V] = {
+ serializerClass = cls
+ this
+ }
+
+ override def getDependencies: Seq[Dependency[_]] = {
+ List(new ShuffleDependency(prev, part, serializerClass))
+ }
override val partitioner = Some(part)
@@ -51,4 +61,9 @@ class ShuffledRDD[K, V](
SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index, context.taskMetrics,
SparkEnv.get.serializerManager.get(serializerClass))
}
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ prev = null
+ }
}
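A note on the ShuffledRDD hunk above: because the serializer is now set after
construction, the ShuffleDependency can no longer be created eagerly in the
constructor (it would capture a still-null serializerClass), so the diff builds
it lazily in getDependencies; prev becomes a @transient var so that
clearDependencies can drop the parent lineage once the shuffle output exists. A
self-contained sketch of that pattern (illustrative only; the names are
hypothetical, not Spark source):

    // Chainable setter plus lazily built dependencies: configuration can run
    // after `new`, and the parent link can be dropped once it is unneeded.
    class ShuffleNode[T](@transient private var parent: Seq[T]) {
      private var serializerClass: String = null

      // Returns `this`, so construction and configuration chain in one expression.
      def setSerializer(cls: String): ShuffleNode[T] = {
        serializerClass = cls
        this
      }

      // Built on demand rather than in the constructor, so it observes any
      // serializer set via setSerializer after construction.
      def dependencies: Seq[(Seq[T], String)] = List((parent, serializerClass))

      // Dropping the parent reference lets the upstream data be garbage
      // collected once the shuffled output has been materialized.
      def clearDependencies(): Unit = { parent = null }
    }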
diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
index 0402b9f250..46b8cafaac 100644
--- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
@@ -49,10 +49,16 @@ import spark.OneToOneDependency
private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassManifest](
@transient var rdd1: RDD[(K, V)],
@transient var rdd2: RDD[(K, W)],
- part: Partitioner,
- val serializerClass: String = null)
+ part: Partitioner)
extends RDD[(K, V)](rdd1.context, Nil) {
+ private var serializerClass: String = null
+
+ def setSerializer(cls: String): SubtractedRDD[K, V, W] = {
+ serializerClass = cls
+ this
+ }
+
override def getDependencies: Seq[Dependency[_]] = {
Seq(rdd1, rdd2).map { rdd =>
if (rdd.partitioner == Some(part)) {
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 752e4b85e6..c686b8cc5a 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -17,17 +17,8 @@
package spark
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashSet
-
import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
-import org.scalatest.prop.Checkers
-import org.scalacheck.Arbitrary._
-import org.scalacheck.Gen
-import org.scalacheck.Prop._
-
-import com.google.common.io.Files
import spark.rdd.ShuffledRDD
import spark.SparkContext._
@@ -59,8 +50,8 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
}
// If the Kryo serializer is not used correctly, the shuffle would fail because the
// default Java serializer cannot handle the non serializable class.
- val c = new ShuffledRDD(b, new HashPartitioner(NUM_BLOCKS),
- classOf[spark.KryoSerializer].getName)
+ val c = new ShuffledRDD(b, new HashPartitioner(NUM_BLOCKS))
+ .setSerializer(classOf[spark.KryoSerializer].getName)
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
assert(c.count === 10)
@@ -81,7 +72,8 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
}
// If the Kryo serializer is not used correctly, the shuffle would fail because the
// default Java serializer cannot handle the non serializable class.
- val c = new ShuffledRDD(b, new HashPartitioner(3), classOf[spark.KryoSerializer].getName)
+ val c = new ShuffledRDD(b, new HashPartitioner(3))
+ .setSerializer(classOf[spark.KryoSerializer].getName)
assert(c.count === 10)
}
@@ -96,7 +88,8 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
// NOTE: The default Java serializer doesn't create zero-sized blocks.
// So, use Kryo
- val c = new ShuffledRDD(b, new HashPartitioner(10), classOf[spark.KryoSerializer].getName)
+ val c = new ShuffledRDD(b, new HashPartitioner(10))
+ .setSerializer(classOf[spark.KryoSerializer].getName)
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
assert(c.count === 4)