author     Reynold Xin <rxin@apache.org>          2014-03-16 09:57:21 -0700
committer  Patrick Wendell <pwendell@gmail.com>   2014-03-16 09:57:21 -0700
commit     f5486e9f75d62919583da5ecf9a9ad00222b2227 (patch)
tree       42bde2b308647eeaef2c7a92aad176916d884310 /core
parent     97e4459e1e4cca8696535e10a91733c15f960107 (diff)
SPARK-1255: Allow user to pass Serializer object instead of class name for shuffle.
This is more general than simply passing a string name and leaves more room for performance optimizations.

Note that this is technically an API-breaking change in the following two ways:
1. The shuffle serializer specification in ShuffleDependency now requires an object instead of a String (the class name), but I suspect nobody else in this world has used this API other than me in GraphX and Shark.
2. Serializers in Spark are now required to be serializable.

Author: Reynold Xin <rxin@apache.org>

Closes #149 from rxin/serializer and squashes the following commits:

5acaccd [Reynold Xin] Properly call serializer's constructors.
2a8d75a [Reynold Xin] Added more documentation for the serializer option in ShuffleDependency.
7420185 [Reynold Xin] Allow user to pass Serializer object instead of class name for shuffle.
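For context, a minimal sketch of the new calling convention, mirroring the updated ShuffleSuite test below. The standalone object, the local SparkContext setup, and the small pair RDD are illustrative assumptions, not part of this commit:

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
import org.apache.spark.rdd.ShuffledRDD
import org.apache.spark.serializer.KryoSerializer

object SerializerExample {
  def main(args: Array[String]) {
    val conf = new SparkConf(loadDefaults = false)
    val sc = new SparkContext("local", "serializer-example")

    // A small pair RDD to shuffle (illustrative data).
    val pairs = sc.parallelize(1 to 100).map(i => (i % 10, i))

    // Before this change: .setSerializer(classOf[KryoSerializer].getName)
    // After this change: a Serializer instance is passed directly.
    val shuffled = new ShuffledRDD[Int, Int, (Int, Int)](pairs, new HashPartitioner(4))
      .setSerializer(new KryoSerializer(conf))

    println(shuffled.count())
    sc.stop()
  }
}

Passing null (the default) falls back to the serializer configured via spark.serializer, resolved through Serializer.getSerializer at fetch/write time.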
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/Dependency.scala                            |  6
-rw-r--r--  core/src/main/scala/org/apache/spark/ShuffleFetcher.scala                        |  2
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala                              | 24
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala                      | 18
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala                  |  7
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala                       | 13
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala                     | 20
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala              |  3
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala             | 27
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala             | 16
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/Serializer.scala                 | 16
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala          | 75
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala |  2
-rw-r--r--  core/src/test/scala/org/apache/spark/ShuffleSuite.scala                          |  9
14 files changed, 99 insertions, 139 deletions
diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index cc30105940..448f87b81e 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -18,6 +18,7 @@
package org.apache.spark
import org.apache.spark.rdd.RDD
+import org.apache.spark.serializer.Serializer
/**
* Base class for dependencies.
@@ -43,12 +44,13 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
* Represents a dependency on the output of a shuffle stage.
* @param rdd the parent RDD
* @param partitioner partitioner used to partition the shuffle output
- * @param serializerClass class name of the serializer to use
+ * @param serializer [[Serializer]] to use. If set to null, the default serializer, as specified
+ * by `spark.serializer` config option, will be used.
*/
class ShuffleDependency[K, V](
@transient rdd: RDD[_ <: Product2[K, V]],
val partitioner: Partitioner,
- val serializerClass: String = null)
+ val serializer: Serializer = null)
extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
val shuffleId: Int = rdd.context.newShuffleId()
diff --git a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
index e8f756c408..a4f69b6b22 100644
--- a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
@@ -29,7 +29,7 @@ private[spark] abstract class ShuffleFetcher {
shuffleId: Int,
reduceId: Int,
context: TaskContext,
- serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[T]
+ serializer: Serializer = SparkEnv.get.serializer): Iterator[T]
/** Stop the fetcher */
def stop() {}
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 5e43b51984..d035d909b7 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -28,7 +28,7 @@ import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.storage.{BlockManager, BlockManagerMaster, BlockManagerMasterActor}
import org.apache.spark.network.ConnectionManager
-import org.apache.spark.serializer.{Serializer, SerializerManager}
+import org.apache.spark.serializer.Serializer
import org.apache.spark.util.{AkkaUtils, Utils}
/**
@@ -41,7 +41,6 @@ import org.apache.spark.util.{AkkaUtils, Utils}
class SparkEnv private[spark] (
val executorId: String,
val actorSystem: ActorSystem,
- val serializerManager: SerializerManager,
val serializer: Serializer,
val closureSerializer: Serializer,
val cacheManager: CacheManager,
@@ -139,16 +138,22 @@ object SparkEnv extends Logging {
// defaultClassName if the property is not set, and return it as a T
def instantiateClass[T](propertyName: String, defaultClassName: String): T = {
val name = conf.get(propertyName, defaultClassName)
- Class.forName(name, true, classLoader).newInstance().asInstanceOf[T]
+ val cls = Class.forName(name, true, classLoader)
+ // First try with the constructor that takes SparkConf. If we can't find one,
+ // use a no-arg constructor instead.
+ try {
+ cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]
+ } catch {
+ case _: NoSuchMethodException =>
+ cls.getConstructor().newInstance().asInstanceOf[T]
+ }
}
- val serializerManager = new SerializerManager
- val serializer = serializerManager.setDefault(
- conf.get("spark.serializer", "org.apache.spark.serializer.JavaSerializer"), conf)
+ val serializer = instantiateClass[Serializer](
+ "spark.serializer", "org.apache.spark.serializer.JavaSerializer")
- val closureSerializer = serializerManager.get(
- conf.get("spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer"),
- conf)
+ val closureSerializer = instantiateClass[Serializer](
+ "spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer")
def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
if (isDriver) {
@@ -220,7 +225,6 @@ object SparkEnv extends Logging {
new SparkEnv(
executorId,
actorSystem,
- serializerManager,
serializer,
closureSerializer,
cacheManager,
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 699a10c96c..8561711931 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap}
+import org.apache.spark.serializer.Serializer
private[spark] sealed trait CoGroupSplitDep extends Serializable
@@ -66,10 +67,10 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
private type CoGroupValue = (Any, Int) // Int is dependency number
private type CoGroupCombiner = Seq[CoGroup]
- private var serializerClass: String = null
+ private var serializer: Serializer = null
- def setSerializer(cls: String): CoGroupedRDD[K] = {
- serializerClass = cls
+ def setSerializer(serializer: Serializer): CoGroupedRDD[K] = {
+ this.serializer = serializer
this
}
@@ -80,7 +81,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
new OneToOneDependency(rdd)
} else {
logDebug("Adding shuffle dependency with " + rdd)
- new ShuffleDependency[Any, Any](rdd, part, serializerClass)
+ new ShuffleDependency[Any, Any](rdd, part, serializer)
}
}
}
@@ -113,18 +114,17 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
// A list of (rdd iterator, dependency number) pairs
val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)]
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
- case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
+ case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
// Read them from the parent
val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]]
rddIterators += ((it, depNum))
- }
- case ShuffleCoGroupSplitDep(shuffleId) => {
+
+ case ShuffleCoGroupSplitDep(shuffleId) =>
// Read map outputs of shuffle
val fetcher = SparkEnv.get.shuffleFetcher
- val ser = SparkEnv.get.serializerManager.get(serializerClass, sparkConf)
+ val ser = Serializer.getSerializer(serializer)
val it = fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser)
rddIterators += ((it, depNum))
- }
}
if (!externalSorting) {
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index b20ed99f89..b0d322fe27 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -44,6 +44,7 @@ import org.apache.spark._
import org.apache.spark.Partitioner.defaultPartitioner
import org.apache.spark.SparkContext._
import org.apache.spark.partial.{BoundedDouble, PartialResult}
+import org.apache.spark.serializer.Serializer
import org.apache.spark.util.SerializableHyperLogLog
/**
@@ -73,7 +74,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
mergeCombiners: (C, C) => C,
partitioner: Partitioner,
mapSideCombine: Boolean = true,
- serializerClass: String = null): RDD[(K, C)] = {
+ serializer: Serializer = null): RDD[(K, C)] = {
require(mergeCombiners != null, "mergeCombiners must be defined") // required as of Spark 0.9.0
if (getKeyClass().isArray) {
if (mapSideCombine) {
@@ -93,13 +94,13 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
aggregator.combineValuesByKey(iter, context)
}, preservesPartitioning = true)
val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner)
- .setSerializer(serializerClass)
+ .setSerializer(serializer)
partitioned.mapPartitionsWithContext((context, iter) => {
new InterruptibleIterator(context, aggregator.combineCombinersByKey(iter, context))
}, preservesPartitioning = true)
} else {
// Don't apply map-side combiner.
- val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass)
+ val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializer)
values.mapPartitionsWithContext((context, iter) => {
new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
}, preservesPartitioning = true)
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index 0bbda25a90..02660ea6a4 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -20,6 +20,7 @@ package org.apache.spark.rdd
import scala.reflect.ClassTag
import org.apache.spark.{Dependency, Partition, Partitioner, ShuffleDependency, SparkEnv, TaskContext}
+import org.apache.spark.serializer.Serializer
private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
override val index = idx
@@ -38,15 +39,15 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
part: Partitioner)
extends RDD[P](prev.context, Nil) {
- private var serializerClass: String = null
+ private var serializer: Serializer = null
- def setSerializer(cls: String): ShuffledRDD[K, V, P] = {
- serializerClass = cls
+ def setSerializer(serializer: Serializer): ShuffledRDD[K, V, P] = {
+ this.serializer = serializer
this
}
override def getDependencies: Seq[Dependency[_]] = {
- List(new ShuffleDependency(prev, part, serializerClass))
+ List(new ShuffleDependency(prev, part, serializer))
}
override val partitioner = Some(part)
@@ -57,8 +58,8 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
override def compute(split: Partition, context: TaskContext): Iterator[P] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
- SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context,
- SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf))
+ val ser = Serializer.getSerializer(serializer)
+ SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, ser)
}
override def clearDependencies() {
diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
index 5fe9f363db..9a09c05bbc 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
@@ -30,6 +30,7 @@ import org.apache.spark.Partitioner
import org.apache.spark.ShuffleDependency
import org.apache.spark.SparkEnv
import org.apache.spark.TaskContext
+import org.apache.spark.serializer.Serializer
/**
* An optimized version of cogroup for set difference/subtraction.
@@ -53,10 +54,10 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
part: Partitioner)
extends RDD[(K, V)](rdd1.context, Nil) {
- private var serializerClass: String = null
+ private var serializer: Serializer = null
- def setSerializer(cls: String): SubtractedRDD[K, V, W] = {
- serializerClass = cls
+ def setSerializer(serializer: Serializer): SubtractedRDD[K, V, W] = {
+ this.serializer = serializer
this
}
@@ -67,7 +68,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
new OneToOneDependency(rdd)
} else {
logDebug("Adding shuffle dependency with " + rdd)
- new ShuffleDependency(rdd, part, serializerClass)
+ new ShuffleDependency(rdd, part, serializer)
}
}
}
@@ -92,7 +93,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
val partition = p.asInstanceOf[CoGroupPartition]
- val serializer = SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf)
+ val ser = Serializer.getSerializer(serializer)
val map = new JHashMap[K, ArrayBuffer[V]]
def getSeq(k: K): ArrayBuffer[V] = {
val seq = map.get(k)
@@ -105,14 +106,13 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
}
}
def integrate(dep: CoGroupSplitDep, op: Product2[K, V] => Unit) = dep match {
- case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
+ case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op)
- }
- case ShuffleCoGroupSplitDep(shuffleId) => {
+
+ case ShuffleCoGroupSplitDep(shuffleId) =>
val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index,
- context, serializer)
+ context, ser)
iter.foreach(op)
- }
}
// the first dep is rdd1; add all values to the map
integrate(partition.deps(0), t => getSeq(t._1) += t._2)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 77789031f4..2a9edf4a76 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -26,6 +26,7 @@ import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.RDDCheckpointData
+import org.apache.spark.serializer.Serializer
import org.apache.spark.storage._
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
@@ -153,7 +154,7 @@ private[spark] class ShuffleMapTask(
try {
// Obtain all the block writers for shuffle blocks.
- val ser = SparkEnv.get.serializerManager.get(dep.serializerClass, SparkEnv.get.conf)
+ val ser = Serializer.getSerializer(dep.serializer)
shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser)
// Write the map output to its associated buckets.
diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index bfa647f7f0..18a68b05fa 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -23,11 +23,10 @@ import java.nio.ByteBuffer
import org.apache.spark.SparkConf
import org.apache.spark.util.ByteBufferInputStream
-private[spark] class JavaSerializationStream(out: OutputStream, conf: SparkConf)
+private[spark] class JavaSerializationStream(out: OutputStream, counterReset: Int)
extends SerializationStream {
- val objOut = new ObjectOutputStream(out)
- var counter = 0
- val counterReset = conf.getInt("spark.serializer.objectStreamReset", 10000)
+ private val objOut = new ObjectOutputStream(out)
+ private var counter = 0
/**
* Calling reset to avoid memory leak:
@@ -51,7 +50,7 @@ private[spark] class JavaSerializationStream(out: OutputStream, conf: SparkConf)
private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoader)
extends DeserializationStream {
- val objIn = new ObjectInputStream(in) {
+ private val objIn = new ObjectInputStream(in) {
override def resolveClass(desc: ObjectStreamClass) =
Class.forName(desc.getName, false, loader)
}
@@ -60,7 +59,7 @@ extends DeserializationStream {
def close() { objIn.close() }
}
-private[spark] class JavaSerializerInstance(conf: SparkConf) extends SerializerInstance {
+private[spark] class JavaSerializerInstance(counterReset: Int) extends SerializerInstance {
def serialize[T](t: T): ByteBuffer = {
val bos = new ByteArrayOutputStream()
val out = serializeStream(bos)
@@ -82,7 +81,7 @@ private[spark] class JavaSerializerInstance(conf: SparkConf) extends SerializerI
}
def serializeStream(s: OutputStream): SerializationStream = {
- new JavaSerializationStream(s, conf)
+ new JavaSerializationStream(s, counterReset)
}
def deserializeStream(s: InputStream): DeserializationStream = {
@@ -97,6 +96,16 @@ private[spark] class JavaSerializerInstance(conf: SparkConf) extends SerializerI
/**
* A Spark serializer that uses Java's built-in serialization.
*/
-class JavaSerializer(conf: SparkConf) extends Serializer {
- def newInstance(): SerializerInstance = new JavaSerializerInstance(conf)
+class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable {
+ private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 10000)
+
+ def newInstance(): SerializerInstance = new JavaSerializerInstance(counterReset)
+
+ override def writeExternal(out: ObjectOutput) {
+ out.writeInt(counterReset)
+ }
+
+ override def readExternal(in: ObjectInput) {
+ counterReset = in.readInt()
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 920490f9d0..6b6d814c1f 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -34,10 +34,14 @@ import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock}
/**
* A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]].
*/
-class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serializer with Logging {
- private val bufferSize = {
- conf.getInt("spark.kryoserializer.buffer.mb", 2) * 1024 * 1024
- }
+class KryoSerializer(conf: SparkConf)
+ extends org.apache.spark.serializer.Serializer
+ with Logging
+ with Serializable {
+
+ private val bufferSize = conf.getInt("spark.kryoserializer.buffer.mb", 2) * 1024 * 1024
+ private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true)
+ private val registrator = conf.getOption("spark.kryo.registrator")
def newKryoOutput() = new KryoOutput(bufferSize)
@@ -48,7 +52,7 @@ class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serial
// Allow disabling Kryo reference tracking if user knows their object graphs don't have loops.
// Do this before we invoke the user registrator so the user registrator can override this.
- kryo.setReferences(conf.getBoolean("spark.kryo.referenceTracking", true))
+ kryo.setReferences(referenceTracking)
for (cls <- KryoSerializer.toRegister) kryo.register(cls)
@@ -58,7 +62,7 @@ class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serial
// Allow the user to register their own classes by setting spark.kryo.registrator
try {
- for (regCls <- conf.getOption("spark.kryo.registrator")) {
+ for (regCls <- registrator) {
logDebug("Running user registrator: " + regCls)
val reg = Class.forName(regCls, true, classLoader).newInstance()
.asInstanceOf[KryoRegistrator]
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
index 16677ab54b..099143494b 100644
--- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -23,21 +23,31 @@ import java.nio.ByteBuffer
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
import org.apache.spark.util.{ByteBufferInputStream, NextIterator}
+import org.apache.spark.SparkEnv
/**
* A serializer. Because some serialization libraries are not thread safe, this class is used to
* create [[org.apache.spark.serializer.SerializerInstance]] objects that do the actual
* serialization and are guaranteed to only be called from one thread at a time.
*
- * Implementations of this trait should have a zero-arg constructor or a constructor that accepts a
- * [[org.apache.spark.SparkConf]] as parameter. If both constructors are defined, the latter takes
- * precedence.
+ * Implementations of this trait should implement:
+ * 1. a zero-arg constructor or a constructor that accepts a [[org.apache.spark.SparkConf]]
+ * as parameter. If both constructors are defined, the latter takes precedence.
+ *
+ * 2. Java serialization interface.
*/
trait Serializer {
def newInstance(): SerializerInstance
}
+object Serializer {
+ def getSerializer(serializer: Serializer): Serializer = {
+ if (serializer == null) SparkEnv.get.serializer else serializer
+ }
+}
+
+
/**
* An instance of a serializer, for use by one thread at a time.
*/
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
deleted file mode 100644
index 65ac0155f4..0000000000
--- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
+++ /dev/null
@@ -1,75 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.serializer
-
-import java.util.concurrent.ConcurrentHashMap
-
-import org.apache.spark.SparkConf
-
-/**
- * A service that returns a serializer object given the serializer's class name. If a previous
- * instance of the serializer object has been created, the get method returns that instead of
- * creating a new one.
- */
-private[spark] class SerializerManager {
- // TODO: Consider moving this into SparkConf itself to remove the global singleton.
-
- private val serializers = new ConcurrentHashMap[String, Serializer]
- private var _default: Serializer = _
-
- def default = _default
-
- def setDefault(clsName: String, conf: SparkConf): Serializer = {
- _default = get(clsName, conf)
- _default
- }
-
- def get(clsName: String, conf: SparkConf): Serializer = {
- if (clsName == null) {
- default
- } else {
- var serializer = serializers.get(clsName)
- if (serializer != null) {
- // If the serializer has been created previously, reuse that.
- serializer
- } else this.synchronized {
- // Otherwise, create a new one. But make sure no other thread has attempted
- // to create another new one at the same time.
- serializer = serializers.get(clsName)
- if (serializer == null) {
- val clsLoader = Thread.currentThread.getContextClassLoader
- val cls = Class.forName(clsName, true, clsLoader)
-
- // First try with the constructor that takes SparkConf. If we can't find one,
- // use a no-arg constructor instead.
- try {
- val constructor = cls.getConstructor(classOf[SparkConf])
- serializer = constructor.newInstance(conf).asInstanceOf[Serializer]
- } catch {
- case _: NoSuchMethodException =>
- val constructor = cls.getConstructor()
- serializer = constructor.newInstance().asInstanceOf[Serializer]
- }
-
- serializers.put(clsName, serializer)
- }
- serializer
- }
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index ed74a31f05..caa06d5b44 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -60,7 +60,7 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C,
- serializer: Serializer = SparkEnv.get.serializerManager.default,
+ serializer: Serializer = SparkEnv.get.serializer,
blockManager: BlockManager = SparkEnv.get.blockManager)
extends Iterable[(K, C)] with Serializable with Logging {
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index abea36f7c8..be6508a40e 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -27,6 +27,9 @@ import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.util.MutablePair
class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
+
+ val conf = new SparkConf(loadDefaults = false)
+
test("groupByKey without compression") {
try {
System.setProperty("spark.shuffle.compress", "false")
@@ -54,7 +57,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
// If the Kryo serializer is not used correctly, the shuffle would fail because the
// default Java serializer cannot handle the non serializable class.
val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)](
- b, new HashPartitioner(NUM_BLOCKS)).setSerializer(classOf[KryoSerializer].getName)
+ b, new HashPartitioner(NUM_BLOCKS)).setSerializer(new KryoSerializer(conf))
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
assert(c.count === 10)
@@ -76,7 +79,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
// If the Kryo serializer is not used correctly, the shuffle would fail because the
// default Java serializer cannot handle the non serializable class.
val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)](
- b, new HashPartitioner(3)).setSerializer(classOf[KryoSerializer].getName)
+ b, new HashPartitioner(3)).setSerializer(new KryoSerializer(conf))
assert(c.count === 10)
}
@@ -92,7 +95,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
// NOTE: The default Java serializer doesn't create zero-sized blocks.
// So, use Kryo
val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10))
- .setSerializer(classOf[KryoSerializer].getName)
+ .setSerializer(new KryoSerializer(conf))
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
assert(c.count === 4)