diff options
author | Matei Zaharia <matei@eecs.berkeley.edu> | 2012-04-10 14:21:02 -0700 |
---|---|---|
committer | Matei Zaharia <matei@eecs.berkeley.edu> | 2012-04-10 14:21:02 -0700 |
commit | 112655f03201c877b5ff3e43519cde8052909095 (patch) | |
tree | 89aa8c8feaafab600d09141170c6a3eec83bed2a | |
parent | a6339741433ec74e06adc8e876eed163e69706f9 (diff) | |
parent | d295ccb43c0a7e642ffc04a20107fb94ab2392f0 (diff) | |
download | spark-112655f03201c877b5ff3e43519cde8052909095.tar.gz spark-112655f03201c877b5ff3e43519cde8052909095.tar.bz2 spark-112655f03201c877b5ff3e43519cde8052909095.zip |
Merge pull request #121 from rxin/kryo-closure
Added an option (spark.closure.serializer) to specify the serializer for closures.
-rw-r--r-- | bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala | 4 | ||||
-rw-r--r-- | core/src/main/scala/spark/Executor.scala | 13 | ||||
-rw-r--r-- | core/src/main/scala/spark/JavaSerializer.scala | 9 | ||||
-rw-r--r-- | core/src/main/scala/spark/KryoSerializer.scala | 15 | ||||
-rw-r--r-- | core/src/main/scala/spark/LocalScheduler.scala | 10 | ||||
-rw-r--r-- | core/src/main/scala/spark/Serializer.scala | 1 | ||||
-rw-r--r-- | core/src/main/scala/spark/SimpleJob.scala | 19 | ||||
-rw-r--r-- | core/src/main/scala/spark/SparkEnv.scala | 15 | ||||
-rw-r--r-- | core/src/main/scala/spark/Utils.scala | 1 |
9 files changed, 73 insertions, 14 deletions
diff --git a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala index 2e38376499..7084ff97d9 100644 --- a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala +++ b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala @@ -126,6 +126,10 @@ class WPRSerializerInstance extends SerializerInstance { throw new UnsupportedOperationException() } + def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { + throw new UnsupportedOperationException() + } + def outputStream(s: OutputStream): SerializationStream = { new WPRSerializationStream(s) } diff --git a/core/src/main/scala/spark/Executor.scala b/core/src/main/scala/spark/Executor.scala index 71a2ded7e7..c1795e02a4 100644 --- a/core/src/main/scala/spark/Executor.scala +++ b/core/src/main/scala/spark/Executor.scala @@ -57,16 +57,17 @@ class Executor extends org.apache.mesos.Executor with Logging { extends Runnable { override def run() = { val tid = desc.getTaskId.getValue + SparkEnv.set(env) + Thread.currentThread.setContextClassLoader(classLoader) + val ser = SparkEnv.get.closureSerializer.newInstance() logInfo("Running task ID " + tid) d.sendStatusUpdate(TaskStatus.newBuilder() .setTaskId(desc.getTaskId) .setState(TaskState.TASK_RUNNING) .build()) try { - SparkEnv.set(env) - Thread.currentThread.setContextClassLoader(classLoader) Accumulators.clear - val task = Utils.deserialize[Task[Any]](desc.getData.toByteArray, classLoader) + val task = ser.deserialize[Task[Any]](desc.getData.toByteArray, classLoader) for (gen <- task.generation) {// Update generation if any is set env.mapOutputTracker.updateGeneration(gen) } @@ -76,7 +77,7 @@ class Executor extends org.apache.mesos.Executor with Logging { d.sendStatusUpdate(TaskStatus.newBuilder() .setTaskId(desc.getTaskId) .setState(TaskState.TASK_FINISHED) - .setData(ByteString.copyFrom(Utils.serialize(result))) + .setData(ByteString.copyFrom(ser.serialize(result))) .build()) logInfo("Finished task ID " + tid) } catch { @@ -85,7 +86,7 @@ class Executor extends org.apache.mesos.Executor with Logging { d.sendStatusUpdate(TaskStatus.newBuilder() .setTaskId(desc.getTaskId) .setState(TaskState.TASK_FAILED) - .setData(ByteString.copyFrom(Utils.serialize(reason))) + .setData(ByteString.copyFrom(ser.serialize(reason))) .build()) } case t: Throwable => { @@ -93,7 +94,7 @@ class Executor extends org.apache.mesos.Executor with Logging { d.sendStatusUpdate(TaskStatus.newBuilder() .setTaskId(desc.getTaskId) .setState(TaskState.TASK_FAILED) - .setData(ByteString.copyFrom(Utils.serialize(reason))) + .setData(ByteString.copyFrom(ser.serialize(reason))) .build()) // TODO: Handle errors in tasks less dramatically diff --git a/core/src/main/scala/spark/JavaSerializer.scala b/core/src/main/scala/spark/JavaSerializer.scala index e7cd4364ee..80f615eeb0 100644 --- a/core/src/main/scala/spark/JavaSerializer.scala +++ b/core/src/main/scala/spark/JavaSerializer.scala @@ -34,6 +34,15 @@ class JavaSerializerInstance extends SerializerInstance { in.readObject().asInstanceOf[T] } + def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { + val bis = new ByteArrayInputStream(bytes) + val ois = new ObjectInputStream(bis) { + override def resolveClass(desc: ObjectStreamClass) = + Class.forName(desc.getName, false, loader) + } + return ois.readObject.asInstanceOf[T] + } + def outputStream(s: OutputStream): SerializationStream = { new JavaSerializationStream(s) } diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala index 7d25b965d2..5693613d6d 100644 --- a/core/src/main/scala/spark/KryoSerializer.scala +++ b/core/src/main/scala/spark/KryoSerializer.scala @@ -9,6 +9,7 @@ import scala.collection.mutable import com.esotericsoftware.kryo._ import com.esotericsoftware.kryo.{Serializer => KSerializer} +import com.esotericsoftware.kryo.serialize.ClassSerializer import de.javakaffee.kryoserializers.KryoReflectionFactorySupport /** @@ -100,6 +101,14 @@ class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance { buf.readClassAndObject(bytes).asInstanceOf[T] } + def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { + val oldClassLoader = ks.kryo.getClassLoader + ks.kryo.setClassLoader(loader) + val obj = buf.readClassAndObject(bytes).asInstanceOf[T] + ks.kryo.setClassLoader(oldClassLoader) + obj + } + def outputStream(s: OutputStream): SerializationStream = { new KryoSerializationStream(ks.kryo, ks.threadByteBuf.get(), s) } @@ -129,6 +138,8 @@ class KryoSerializer extends Serializer with Logging { } def createKryo(): Kryo = { + // This is used so we can serialize/deserialize objects without a zero-arg + // constructor. val kryo = new KryoReflectionFactorySupport() // Register some commonly used classes @@ -150,6 +161,10 @@ class KryoSerializer extends Serializer with Logging { kryo.register(obj.getClass) } + // Register the following classes for passing closures. + kryo.register(classOf[Class[_]], new ClassSerializer(kryo)) + kryo.setRegistrationOptional(true) + // Register some commonly used Scala singleton objects. Because these // are singletons, we must return the exact same local object when we // deserialize rather than returning a clone as FieldSerializer would. diff --git a/core/src/main/scala/spark/LocalScheduler.scala b/core/src/main/scala/spark/LocalScheduler.scala index 0cbc68ffc5..8972d6c290 100644 --- a/core/src/main/scala/spark/LocalScheduler.scala +++ b/core/src/main/scala/spark/LocalScheduler.scala @@ -38,9 +38,13 @@ private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGSchedule // Serialize and deserialize the task so that accumulators are changed to thread-local ones; // this adds a bit of unnecessary overhead but matches how the Mesos Executor works. Accumulators.clear - val bytes = Utils.serialize(task) - logInfo("Size of task " + idInJob + " is " + bytes.size + " bytes") - val deserializedTask = Utils.deserialize[Task[_]]( + val ser = SparkEnv.get.closureSerializer.newInstance() + val startTime = System.currentTimeMillis + val bytes = ser.serialize(task) + val timeTaken = System.currentTimeMillis - startTime + logInfo("Size of task %d is %d bytes and took %d ms to serialize by %s" + .format(idInJob, bytes.size, timeTaken, ser.getClass.getName)) + val deserializedTask = ser.deserialize[Task[_]]( bytes, Thread.currentThread.getContextClassLoader) val result: Any = deserializedTask.run(attemptId) val accumUpdates = Accumulators.values diff --git a/core/src/main/scala/spark/Serializer.scala b/core/src/main/scala/spark/Serializer.scala index 15fca9fcda..2429bbfeb9 100644 --- a/core/src/main/scala/spark/Serializer.scala +++ b/core/src/main/scala/spark/Serializer.scala @@ -16,6 +16,7 @@ trait Serializer { trait SerializerInstance { def serialize[T](t: T): Array[Byte] def deserialize[T](bytes: Array[Byte]): T + def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T def outputStream(s: OutputStream): SerializationStream def inputStream(s: InputStream): DeserializationStream } diff --git a/core/src/main/scala/spark/SimpleJob.scala b/core/src/main/scala/spark/SimpleJob.scala index 5e42ae6ecd..b221c2e309 100644 --- a/core/src/main/scala/spark/SimpleJob.scala +++ b/core/src/main/scala/spark/SimpleJob.scala @@ -30,6 +30,9 @@ class SimpleJob( // Maximum times a task is allowed to fail before failing the job val MAX_TASK_FAILURES = 4 + // Serializer for closures and tasks. + val ser = SparkEnv.get.closureSerializer.newInstance() + val callingThread = Thread.currentThread val tasks = tasksSeq.toArray val numTasks = tasks.length @@ -170,8 +173,14 @@ class SimpleJob( .setType(Resource.Type.SCALAR) .setScalar(Resource.Scalar.newBuilder().setValue(CPUS_PER_TASK).build()) .build() - val serializedTask = Utils.serialize(task) - logDebug("Serialized size: " + serializedTask.size) + + val startTime = System.currentTimeMillis + val serializedTask = ser.serialize(task) + val timeTaken = System.currentTimeMillis - startTime + + logInfo("Size of task %d:%d is %d bytes and took %d ms to serialize by %s" + .format(jobId, index, serializedTask.size, timeTaken, ser.getClass.getName)) + val taskName = "task %d:%d".format(jobId, index) return Some(TaskDescription.newBuilder() .setTaskId(taskId) @@ -208,7 +217,8 @@ class SimpleJob( tasksFinished += 1 logInfo("Finished TID %s (progress: %d/%d)".format(tid, tasksFinished, numTasks)) // Deserialize task result - val result = Utils.deserialize[TaskResult[_]](status.getData.toByteArray) + val result = ser.deserialize[TaskResult[_]]( + status.getData.toByteArray) sched.taskEnded(tasks(index), Success, result.value, result.accumUpdates) // Mark finished and stop if we've finished all the tasks finished(index) = true @@ -230,7 +240,8 @@ class SimpleJob( // Check if the problem is a map output fetch failure. In that case, this // task will never succeed on any node, so tell the scheduler about it. if (status.getData != null && status.getData.size > 0) { - val reason = Utils.deserialize[TaskEndReason](status.getData.toByteArray) + val reason = ser.deserialize[TaskEndReason]( + status.getData.toByteArray) reason match { case fetchFailed: FetchFailed => logInfo("Loss was due to fetch failure from " + fetchFailed.serverUri) diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index e2d1562e35..cd752f8b65 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -3,6 +3,7 @@ package spark class SparkEnv ( val cache: Cache, val serializer: Serializer, + val closureSerializer: Serializer, val cacheTracker: CacheTracker, val mapOutputTracker: MapOutputTracker, val shuffleFetcher: ShuffleFetcher, @@ -27,6 +28,11 @@ object SparkEnv { val serializerClass = System.getProperty("spark.serializer", "spark.JavaSerializer") val serializer = Class.forName(serializerClass).newInstance().asInstanceOf[Serializer] + val closureSerializerClass = + System.getProperty("spark.closure.serializer", "spark.JavaSerializer") + val closureSerializer = + Class.forName(closureSerializerClass).newInstance().asInstanceOf[Serializer] + val cacheTracker = new CacheTracker(isMaster, cache) val mapOutputTracker = new MapOutputTracker(isMaster) @@ -38,6 +44,13 @@ object SparkEnv { val shuffleMgr = new ShuffleManager() - new SparkEnv(cache, serializer, cacheTracker, mapOutputTracker, shuffleFetcher, shuffleMgr) + new SparkEnv( + cache, + serializer, + closureSerializer, + cacheTracker, + mapOutputTracker, + shuffleFetcher, + shuffleMgr) } } diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 58b5fa6bbd..55f2e0691d 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -12,6 +12,7 @@ import scala.util.Random * Various utility methods used by Spark. */ object Utils { + def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() val oos = new ObjectOutputStream(bos) |