author | Dmitriy Lyubimov <dlyubimov@apache.org> | 2013-07-22 14:42:47 -0700 |
---|---|---|
committer | Dmitriy Lyubimov <dlyubimov@apache.org> | 2013-07-22 14:42:47 -0700 |
commit | b4b230e6067d9a8290b79fe56bbfdea3fb12fdbb (patch) | |
tree | 365ee41632952a086a4899a362695aace3a4f398 /core | |
parent | 2ab311f4cee3f918dc28daaebd287b11c9f63429 (diff) | |
Fixing for LocalScheduler with test, that much works ..
Diffstat (limited to 'core')
core/src/main/scala/spark/scheduler/TaskResult.scala | 27 +++++++++++++++++++++++++--
core/src/main/scala/spark/scheduler/local/LocalScheduler.scala | 6 ++++--
core/src/test/scala/spark/KryoSerializerSuite.scala | 15 +++++++++++++++
3 files changed, 44 insertions(+), 4 deletions(-)
diff --git a/core/src/main/scala/spark/scheduler/TaskResult.scala b/core/src/main/scala/spark/scheduler/TaskResult.scala
index 6de0aa7adf..0b459ea600 100644
--- a/core/src/main/scala/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/spark/scheduler/TaskResult.scala
@@ -4,6 +4,8 @@ import java.io._
 
 import scala.collection.mutable.Map
 import spark.executor.TaskMetrics
+import spark.SparkEnv
+import java.nio.ByteBuffer
 
 // Task result. Also contains updates to accumulator variables.
 // TODO: Use of distributed cache to return result is a hack to get around
@@ -13,7 +15,19 @@ class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics:
   def this() = this(null.asInstanceOf[T], null, null)
 
   override def writeExternal(out: ObjectOutput) {
-    out.writeObject(value)
+
+    val objectSer = SparkEnv.get.serializer.newInstance()
+    val bb = objectSer.serialize(value)
+
+    out.writeInt(bb.remaining())
+    if (bb.hasArray) {
+      out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
+    } else {
+      val bbval = new Array[Byte](bb.remaining())
+      bb.get(bbval)
+      out.write(bbval)
+    }
+
     out.writeInt(accumUpdates.size)
     for ((key, value) <- accumUpdates) {
       out.writeLong(key)
@@ -23,7 +37,16 @@ class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics:
   }
 
   override def readExternal(in: ObjectInput) {
-    value = in.readObject().asInstanceOf[T]
+
+    // this doesn't work since SparkEnv.get == null
+    // in this context
+    val objectSer = SparkEnv.get.serializer.newInstance()
+
+    val blen = in.readInt()
+    val byteVal = new Array[Byte](blen)
+    in.readFully(byteVal)
+    value = objectSer.deserialize(ByteBuffer.wrap(byteVal))
+
     val numUpdates = in.readInt
     if (numUpdates == 0) {
       accumUpdates = null
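The `TaskResult` change above replaces default Java serialization of `value` with whatever serializer `SparkEnv.get.serializer` is configured to provide (e.g. Kryo), framed on the wire as a length-prefixed byte block. Below is a minimal, self-contained sketch of that framing pattern. The names `LengthPrefixedValue`, `toByteBuffer`, and `fromByteBuffer` are hypothetical: plain Java serialization stands in for Spark's pluggable serializer, since `SparkEnv` is not available outside Spark.

```scala
import java.io._
import java.nio.ByteBuffer

object LengthPrefixedValue {
  // Stand-in for SparkEnv.get.serializer: Java serialization to a ByteBuffer.
  def toByteBuffer(value: Any): ByteBuffer = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(value)
    oos.close()
    ByteBuffer.wrap(bos.toByteArray)
  }

  def fromByteBuffer(bb: ByteBuffer): Any = {
    val bytes = new Array[Byte](bb.remaining())
    bb.get(bytes)
    new ObjectInputStream(new ByteArrayInputStream(bytes)).readObject()
  }

  // Mirrors the patched writeExternal: a length prefix, then the raw bytes.
  def write(out: ObjectOutput, value: Any): Unit = {
    val bb = toByteBuffer(value)
    out.writeInt(bb.remaining())
    if (bb.hasArray) {
      // Heap buffer: write straight from the backing array, no copy.
      out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
    } else {
      // Direct buffer: no backing array, so drain into a temporary one.
      val tmp = new Array[Byte](bb.remaining())
      bb.get(tmp)
      out.write(tmp)
    }
  }

  // Mirrors the patched readExternal: read the length, then exactly that
  // many bytes, and hand them back to the value serializer.
  def read(in: ObjectInput): Any = {
    val len = in.readInt()
    val bytes = new Array[Byte](len)
    in.readFully(bytes)
    fromByteBuffer(ByteBuffer.wrap(bytes))
  }

  def main(args: Array[String]): Unit = {
    val bos = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bos)
    write(out, "hello")
    out.close()
    val in = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray))
    assert(read(in) == "hello")
  }
}
```

The `hasArray` branch matters because a serializer may hand back a direct `ByteBuffer`, which exposes no backing array; array-backed buffers can be written without the extra copy.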
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index 93d4318b29..0b800ec740 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -145,6 +145,7 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
       // Set the Spark execution environment for the worker thread
       SparkEnv.set(env)
       val ser = SparkEnv.get.closureSerializer.newInstance()
+      val objectSer = SparkEnv.get.serializer.newInstance()
       try {
         Accumulators.clear()
         Thread.currentThread().setContextClassLoader(classLoader)
@@ -165,9 +166,9 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
         // executor does. This is useful to catch serialization errors early
         // on in development (so when users move their local Spark programs
         // to the cluster, they don't get surprised by serialization errors).
-        val serResult = ser.serialize(result)
+        val serResult = objectSer.serialize(result)
         deserializedTask.metrics.get.resultSize = serResult.limit()
-        val resultToReturn = ser.deserialize[Any](serResult)
+        val resultToReturn = objectSer.deserialize[Any](serResult)
         val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
           ser.serialize(Accumulators.values))
         logInfo("Finished " + taskId)
@@ -218,6 +219,7 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
     val taskSetId = taskIdToTaskSetId(taskId)
     val taskSetManager = activeTaskSets(taskSetId)
     taskSetTaskIds(taskSetId) -= taskId
+    SparkEnv.set(env)
     taskSetManager.statusUpdate(taskId, state, serializedData)
   }
 }
diff --git a/core/src/test/scala/spark/KryoSerializerSuite.scala b/core/src/test/scala/spark/KryoSerializerSuite.scala
index 327e2ff848..a2f91648b6 100644
--- a/core/src/test/scala/spark/KryoSerializerSuite.scala
+++ b/core/src/test/scala/spark/KryoSerializerSuite.scala
@@ -7,6 +7,7 @@ import org.scalatest.FunSuite
 import com.esotericsoftware.kryo._
 
 import SparkContext._
+import spark.test.{ClassWithoutNoArgConstructor, MyRegistrator}
 
 class KryoSerializerSuite extends FunSuite {
   test("basic types") {
@@ -109,6 +110,20 @@ class KryoSerializerSuite extends FunSuite {
     System.clearProperty("spark.kryo.registrator")
   }
 
+  test("kryo-collect") {
+    System.setProperty("spark.serializer", "spark.KryoSerializer")
+    System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName)
+
+    val sc = new SparkContext("local", "kryoTest")
+    val control = 1 :: 2 :: Nil
+    val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)).collect().map(_.x)
+    assert(control == result.toSeq)
+
+    System.clearProperty("spark.kryo.registrator")
+    System.clearProperty("spark.serializer")
+  }
+
 }
 
 package test {
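A note on the one-line `SparkEnv.set(env)` added before `taskSetManager.statusUpdate(...)`: in Spark of this era, `SparkEnv.get` was a thread-local lookup, so the `// this doesn't work since SparkEnv.get == null` problem in `readExternal` arises whenever the deserializing thread never called `SparkEnv.set` itself. The sketch below illustrates that failure mode; `Env` and `ThreadLocalDemo` are hypothetical stand-ins, since the real `SparkEnv` also carries serializers, the block manager, and more.

```scala
// Hypothetical stand-in for the thread-local lookup behind SparkEnv.set/get.
object Env {
  private val local = new ThreadLocal[AnyRef]
  def set(env: AnyRef): Unit = local.set(env)
  def get: AnyRef = local.get() // null on a thread that never called set()
}

object ThreadLocalDemo {
  def main(args: Array[String]): Unit = {
    Env.set("driver-env")

    val worker = new Thread(new Runnable {
      def run(): Unit = {
        // A plain ThreadLocal is not inherited by child threads, so code
        // that calls Env.get here (analogous to TaskResult.readExternal
        // calling SparkEnv.get.serializer) sees null unless this thread
        // first calls Env.set(...) itself.
        assert(Env.get == null)
        Env.set("worker-env")
        assert(Env.get == "worker-env")
      }
    })
    worker.start()
    worker.join()
  }
}
```

That is exactly the shape of the fix in `LocalScheduler`: set the environment on the thread that is about to deserialize the `TaskResult`, immediately before handing it to `statusUpdate`.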