author     Matei Zaharia <matei@eecs.berkeley.edu>   2012-04-10 14:21:02 -0700
committer  Matei Zaharia <matei@eecs.berkeley.edu>   2012-04-10 14:21:02 -0700
commit     112655f03201c877b5ff3e43519cde8052909095 (patch)
tree       89aa8c8feaafab600d09141170c6a3eec83bed2a
parent     a6339741433ec74e06adc8e876eed163e69706f9 (diff)
parent     d295ccb43c0a7e642ffc04a20107fb94ab2392f0 (diff)
Merge pull request #121 from rxin/kryo-closure
Added an option (spark.closure.serializer) to specify the serializer for closures.
-rw-r--r--  bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala |  4
-rw-r--r--  core/src/main/scala/spark/Executor.scala                                     | 13
-rw-r--r--  core/src/main/scala/spark/JavaSerializer.scala                               |  9
-rw-r--r--  core/src/main/scala/spark/KryoSerializer.scala                               | 15
-rw-r--r--  core/src/main/scala/spark/LocalScheduler.scala                               | 10
-rw-r--r--  core/src/main/scala/spark/Serializer.scala                                   |  1
-rw-r--r--  core/src/main/scala/spark/SimpleJob.scala                                    | 19
-rw-r--r--  core/src/main/scala/spark/SparkEnv.scala                                     | 15
-rw-r--r--  core/src/main/scala/spark/Utils.scala                                        |  1
9 files changed, 73 insertions(+), 14 deletions(-)
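
Before the per-file diffs, a minimal usage sketch (not part of the patch): the property name and default come from the SparkEnv change below, while the local master, app name, and sample job are purely illustrative.

// Opt closures into Kryo serialization before the context is created.
System.setProperty("spark.closure.serializer", "spark.KryoSerializer")
val sc = new SparkContext("local", "closure-serializer-demo")
// Task closures shipped from this context are now serialized by
// spark.KryoSerializer instead of the default spark.JavaSerializer.
sc.parallelize(1 to 100).map(_ * 2).count()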
diff --git a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala
index 2e38376499..7084ff97d9 100644
--- a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala
+++ b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala
@@ -126,6 +126,10 @@ class WPRSerializerInstance extends SerializerInstance {
throw new UnsupportedOperationException()
}
+ def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
+ throw new UnsupportedOperationException()
+ }
+
def outputStream(s: OutputStream): SerializationStream = {
new WPRSerializationStream(s)
}
diff --git a/core/src/main/scala/spark/Executor.scala b/core/src/main/scala/spark/Executor.scala
index 71a2ded7e7..c1795e02a4 100644
--- a/core/src/main/scala/spark/Executor.scala
+++ b/core/src/main/scala/spark/Executor.scala
@@ -57,16 +57,17 @@ class Executor extends org.apache.mesos.Executor with Logging {
extends Runnable {
override def run() = {
val tid = desc.getTaskId.getValue
+ SparkEnv.set(env)
+ Thread.currentThread.setContextClassLoader(classLoader)
+ val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo("Running task ID " + tid)
d.sendStatusUpdate(TaskStatus.newBuilder()
.setTaskId(desc.getTaskId)
.setState(TaskState.TASK_RUNNING)
.build())
try {
- SparkEnv.set(env)
- Thread.currentThread.setContextClassLoader(classLoader)
Accumulators.clear
- val task = Utils.deserialize[Task[Any]](desc.getData.toByteArray, classLoader)
+ val task = ser.deserialize[Task[Any]](desc.getData.toByteArray, classLoader)
for (gen <- task.generation) {// Update generation if any is set
env.mapOutputTracker.updateGeneration(gen)
}
@@ -76,7 +77,7 @@ class Executor extends org.apache.mesos.Executor with Logging {
d.sendStatusUpdate(TaskStatus.newBuilder()
.setTaskId(desc.getTaskId)
.setState(TaskState.TASK_FINISHED)
- .setData(ByteString.copyFrom(Utils.serialize(result)))
+ .setData(ByteString.copyFrom(ser.serialize(result)))
.build())
logInfo("Finished task ID " + tid)
} catch {
@@ -85,7 +86,7 @@ class Executor extends org.apache.mesos.Executor with Logging {
d.sendStatusUpdate(TaskStatus.newBuilder()
.setTaskId(desc.getTaskId)
.setState(TaskState.TASK_FAILED)
- .setData(ByteString.copyFrom(Utils.serialize(reason)))
+ .setData(ByteString.copyFrom(ser.serialize(reason)))
.build())
}
case t: Throwable => {
@@ -93,7 +94,7 @@ class Executor extends org.apache.mesos.Executor with Logging {
d.sendStatusUpdate(TaskStatus.newBuilder()
.setTaskId(desc.getTaskId)
.setState(TaskState.TASK_FAILED)
- .setData(ByteString.copyFrom(Utils.serialize(reason)))
+ .setData(ByteString.copyFrom(ser.serialize(reason)))
.build())
// TODO: Handle errors in tasks less dramatically
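
Stripped of the Mesos status-update plumbing, the executor-side flow above reduces to the following sketch. The helper is hypothetical, and the fixed attempt id is a simplification; it only illustrates the ordering the patch introduces.

def runSerializedTask(env: SparkEnv, taskBytes: Array[Byte], loader: ClassLoader): Any = {
  // Make the environment and class loader visible before asking the
  // closure serializer for a per-task instance.
  SparkEnv.set(env)
  Thread.currentThread.setContextClassLoader(loader)
  val ser = SparkEnv.get.closureSerializer.newInstance()
  // Deserialize with the executor's loader so user-defined classes resolve.
  val task = ser.deserialize[Task[Any]](taskBytes, loader)
  task.run(0)  // attemptId fixed at 0 for illustration
}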
diff --git a/core/src/main/scala/spark/JavaSerializer.scala b/core/src/main/scala/spark/JavaSerializer.scala
index e7cd4364ee..80f615eeb0 100644
--- a/core/src/main/scala/spark/JavaSerializer.scala
+++ b/core/src/main/scala/spark/JavaSerializer.scala
@@ -34,6 +34,15 @@ class JavaSerializerInstance extends SerializerInstance {
in.readObject().asInstanceOf[T]
}
+ def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
+ val bis = new ByteArrayInputStream(bytes)
+ val ois = new ObjectInputStream(bis) {
+ override def resolveClass(desc: ObjectStreamClass) =
+ Class.forName(desc.getName, false, loader)
+ }
+ return ois.readObject.asInstanceOf[T]
+ }
+
def outputStream(s: OutputStream): SerializationStream = {
new JavaSerializationStream(s)
}
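
The resolveClass override works the same way outside Spark. A self-contained sketch of the technique the new deserialize(bytes, loader) overload relies on (the helper name is hypothetical):

import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectStreamClass}

def deserializeWith[T](bytes: Array[Byte], loader: ClassLoader): T = {
  val ois = new ObjectInputStream(new ByteArrayInputStream(bytes)) {
    // Resolve classes against the supplied loader rather than the stream's
    // default, so classes visible only to that loader can be deserialized.
    override def resolveClass(desc: ObjectStreamClass): Class[_] =
      Class.forName(desc.getName, false, loader)
  }
  try ois.readObject().asInstanceOf[T] finally ois.close()
}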
diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala
index 7d25b965d2..5693613d6d 100644
--- a/core/src/main/scala/spark/KryoSerializer.scala
+++ b/core/src/main/scala/spark/KryoSerializer.scala
@@ -9,6 +9,7 @@ import scala.collection.mutable
import com.esotericsoftware.kryo._
import com.esotericsoftware.kryo.{Serializer => KSerializer}
+import com.esotericsoftware.kryo.serialize.ClassSerializer
import de.javakaffee.kryoserializers.KryoReflectionFactorySupport
/**
@@ -100,6 +101,14 @@ class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
buf.readClassAndObject(bytes).asInstanceOf[T]
}
+ def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
+ val oldClassLoader = ks.kryo.getClassLoader
+ ks.kryo.setClassLoader(loader)
+ val obj = buf.readClassAndObject(bytes).asInstanceOf[T]
+ ks.kryo.setClassLoader(oldClassLoader)
+ obj
+ }
+
def outputStream(s: OutputStream): SerializationStream = {
new KryoSerializationStream(ks.kryo, ks.threadByteBuf.get(), s)
}
@@ -129,6 +138,8 @@ class KryoSerializer extends Serializer with Logging {
}
def createKryo(): Kryo = {
+ // This is used so we can serialize/deserialize objects without a zero-arg
+ // constructor.
val kryo = new KryoReflectionFactorySupport()
// Register some commonly used classes
@@ -150,6 +161,10 @@ class KryoSerializer extends Serializer with Logging {
kryo.register(obj.getClass)
}
+ // Register the following classes for passing closures.
+ kryo.register(classOf[Class[_]], new ClassSerializer(kryo))
+ kryo.setRegistrationOptional(true)
+
// Register some commonly used Scala singleton objects. Because these
// are singletons, we must return the exact same local object when we
// deserialize rather than returning a clone as FieldSerializer would.
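
The Kryo overload above temporarily repoints the shared Kryo instance at the task's class loader. A hedged sketch of the same swap with the restore in a finally block, so an exception while reading cannot leave the loader switched; the read thunk stands in for buf.readClassAndObject, and this is illustrative rather than the merged code.

def withClassLoader[T](kryo: Kryo, loader: ClassLoader)(read: => AnyRef): T = {
  val oldLoader = kryo.getClassLoader
  kryo.setClassLoader(loader)
  try {
    read.asInstanceOf[T]
  } finally {
    // Always restore the previous loader, even if deserialization throws.
    kryo.setClassLoader(oldLoader)
  }
}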
diff --git a/core/src/main/scala/spark/LocalScheduler.scala b/core/src/main/scala/spark/LocalScheduler.scala
index 0cbc68ffc5..8972d6c290 100644
--- a/core/src/main/scala/spark/LocalScheduler.scala
+++ b/core/src/main/scala/spark/LocalScheduler.scala
@@ -38,9 +38,13 @@ private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGSchedule
// Serialize and deserialize the task so that accumulators are changed to thread-local ones;
// this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
Accumulators.clear
- val bytes = Utils.serialize(task)
- logInfo("Size of task " + idInJob + " is " + bytes.size + " bytes")
- val deserializedTask = Utils.deserialize[Task[_]](
+ val ser = SparkEnv.get.closureSerializer.newInstance()
+ val startTime = System.currentTimeMillis
+ val bytes = ser.serialize(task)
+ val timeTaken = System.currentTimeMillis - startTime
+ logInfo("Size of task %d is %d bytes and took %d ms to serialize by %s"
+ .format(idInJob, bytes.size, timeTaken, ser.getClass.getName))
+ val deserializedTask = ser.deserialize[Task[_]](
bytes, Thread.currentThread.getContextClassLoader)
val result: Any = deserializedTask.run(attemptId)
val accumUpdates = Accumulators.values
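
Both LocalScheduler here and SimpleJob below now time how long the closure serializer spends on each task. The pattern, pulled into a hypothetical helper for clarity:

def timedSerialize[T](ser: SerializerInstance, label: String, value: T): Array[Byte] = {
  val start = System.currentTimeMillis
  val bytes = ser.serialize(value)
  val elapsed = System.currentTimeMillis - start
  // Mirrors the log line added in this patch: size, elapsed time, serializer class.
  println("Size of %s is %d bytes and took %d ms to serialize by %s"
    .format(label, bytes.size, elapsed, ser.getClass.getName))
  bytes
}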
diff --git a/core/src/main/scala/spark/Serializer.scala b/core/src/main/scala/spark/Serializer.scala
index 15fca9fcda..2429bbfeb9 100644
--- a/core/src/main/scala/spark/Serializer.scala
+++ b/core/src/main/scala/spark/Serializer.scala
@@ -16,6 +16,7 @@ trait Serializer {
trait SerializerInstance {
def serialize[T](t: T): Array[Byte]
def deserialize[T](bytes: Array[Byte]): T
+ def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T
def outputStream(s: OutputStream): SerializationStream
def inputStream(s: InputStream): DeserializationStream
}
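
With the extra overload, every SerializerInstance must now be class-loader aware. A minimal conforming implementation can simply delegate to the single-argument method when a custom loader adds nothing; the wrapper below is a hypothetical example, not part of the patch.

import java.io.{InputStream, OutputStream}

class PassThroughInstance(underlying: SerializerInstance) extends SerializerInstance {
  def serialize[T](t: T): Array[Byte] = underlying.serialize(t)
  def deserialize[T](bytes: Array[Byte]): T = underlying.deserialize(bytes)
  // New requirement from this patch: accept a loader, here ignored.
  def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = underlying.deserialize(bytes)
  def outputStream(s: OutputStream): SerializationStream = underlying.outputStream(s)
  def inputStream(s: InputStream): DeserializationStream = underlying.inputStream(s)
}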
diff --git a/core/src/main/scala/spark/SimpleJob.scala b/core/src/main/scala/spark/SimpleJob.scala
index 5e42ae6ecd..b221c2e309 100644
--- a/core/src/main/scala/spark/SimpleJob.scala
+++ b/core/src/main/scala/spark/SimpleJob.scala
@@ -30,6 +30,9 @@ class SimpleJob(
// Maximum times a task is allowed to fail before failing the job
val MAX_TASK_FAILURES = 4
+ // Serializer for closures and tasks.
+ val ser = SparkEnv.get.closureSerializer.newInstance()
+
val callingThread = Thread.currentThread
val tasks = tasksSeq.toArray
val numTasks = tasks.length
@@ -170,8 +173,14 @@ class SimpleJob(
.setType(Resource.Type.SCALAR)
.setScalar(Resource.Scalar.newBuilder().setValue(CPUS_PER_TASK).build())
.build()
- val serializedTask = Utils.serialize(task)
- logDebug("Serialized size: " + serializedTask.size)
+
+ val startTime = System.currentTimeMillis
+ val serializedTask = ser.serialize(task)
+ val timeTaken = System.currentTimeMillis - startTime
+
+ logInfo("Size of task %d:%d is %d bytes and took %d ms to serialize by %s"
+ .format(jobId, index, serializedTask.size, timeTaken, ser.getClass.getName))
+
val taskName = "task %d:%d".format(jobId, index)
return Some(TaskDescription.newBuilder()
.setTaskId(taskId)
@@ -208,7 +217,8 @@ class SimpleJob(
tasksFinished += 1
logInfo("Finished TID %s (progress: %d/%d)".format(tid, tasksFinished, numTasks))
// Deserialize task result
- val result = Utils.deserialize[TaskResult[_]](status.getData.toByteArray)
+ val result = ser.deserialize[TaskResult[_]](
+ status.getData.toByteArray)
sched.taskEnded(tasks(index), Success, result.value, result.accumUpdates)
// Mark finished and stop if we've finished all the tasks
finished(index) = true
@@ -230,7 +240,8 @@ class SimpleJob(
// Check if the problem is a map output fetch failure. In that case, this
// task will never succeed on any node, so tell the scheduler about it.
if (status.getData != null && status.getData.size > 0) {
- val reason = Utils.deserialize[TaskEndReason](status.getData.toByteArray)
+ val reason = ser.deserialize[TaskEndReason](
+ status.getData.toByteArray)
reason match {
case fetchFailed: FetchFailed =>
logInfo("Loss was due to fetch failure from " + fetchFailed.serverUri)
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index e2d1562e35..cd752f8b65 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -3,6 +3,7 @@ package spark
class SparkEnv (
val cache: Cache,
val serializer: Serializer,
+ val closureSerializer: Serializer,
val cacheTracker: CacheTracker,
val mapOutputTracker: MapOutputTracker,
val shuffleFetcher: ShuffleFetcher,
@@ -27,6 +28,11 @@ object SparkEnv {
val serializerClass = System.getProperty("spark.serializer", "spark.JavaSerializer")
val serializer = Class.forName(serializerClass).newInstance().asInstanceOf[Serializer]
+ val closureSerializerClass =
+ System.getProperty("spark.closure.serializer", "spark.JavaSerializer")
+ val closureSerializer =
+ Class.forName(closureSerializerClass).newInstance().asInstanceOf[Serializer]
+
val cacheTracker = new CacheTracker(isMaster, cache)
val mapOutputTracker = new MapOutputTracker(isMaster)
@@ -38,6 +44,13 @@ object SparkEnv {
val shuffleMgr = new ShuffleManager()
- new SparkEnv(cache, serializer, cacheTracker, mapOutputTracker, shuffleFetcher, shuffleMgr)
+ new SparkEnv(
+ cache,
+ serializer,
+ closureSerializer,
+ cacheTracker,
+ mapOutputTracker,
+ shuffleFetcher,
+ shuffleMgr)
}
}
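
SparkEnv now performs the same reflective construction twice, once per property. A sketch of that pattern as a hypothetical helper, using the property names and defaults that appear above:

def loadSerializer(property: String, default: String): Serializer = {
  // Read the class name from a system property, falling back to the default,
  // then instantiate it through its no-arg constructor.
  val className = System.getProperty(property, default)
  Class.forName(className).newInstance().asInstanceOf[Serializer]
}

val serializer        = loadSerializer("spark.serializer", "spark.JavaSerializer")
val closureSerializer = loadSerializer("spark.closure.serializer", "spark.JavaSerializer")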
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index 58b5fa6bbd..55f2e0691d 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -12,6 +12,7 @@ import scala.util.Random
* Various utility methods used by Spark.
*/
object Utils {
+
def serialize[T](o: T): Array[Byte] = {
val bos = new ByteArrayOutputStream()
val oos = new ObjectOutputStream(bos)