author     Dmitriy Lyubimov <dlyubimov@apache.org>    2013-07-22 14:42:47 -0700
committer  Dmitriy Lyubimov <dlyubimov@apache.org>    2013-07-22 14:42:47 -0700
commit     b4b230e6067d9a8290b79fe56bbfdea3fb12fdbb (patch)
tree       365ee41632952a086a4899a362695aace3a4f398 /core
parent     2ab311f4cee3f918dc28daaebd287b11c9f63429 (diff)
Fixing LocalScheduler, with a test; that much works.
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/spark/scheduler/TaskResult.scala            | 27
-rw-r--r--  core/src/main/scala/spark/scheduler/local/LocalScheduler.scala  |  6
-rw-r--r--  core/src/test/scala/spark/KryoSerializerSuite.scala             | 15
3 files changed, 44 insertions(+), 4 deletions(-)
diff --git a/core/src/main/scala/spark/scheduler/TaskResult.scala b/core/src/main/scala/spark/scheduler/TaskResult.scala
index 6de0aa7adf..0b459ea600 100644
--- a/core/src/main/scala/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/spark/scheduler/TaskResult.scala
@@ -4,6 +4,8 @@ import java.io._
import scala.collection.mutable.Map
import spark.executor.TaskMetrics
+import spark.SparkEnv
+import java.nio.ByteBuffer
// Task result. Also contains updates to accumulator variables.
// TODO: Use of distributed cache to return result is a hack to get around
@@ -13,7 +15,19 @@ class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics:
def this() = this(null.asInstanceOf[T], null, null)
override def writeExternal(out: ObjectOutput) {
- out.writeObject(value)
+
+ val objectSer = SparkEnv.get.serializer.newInstance()
+ val bb = objectSer.serialize(value)
+
+ out.writeInt( bb.remaining())
+ if (bb.hasArray) {
+ out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
+ } else {
+ val bbval = new Array[Byte](bb.remaining())
+ bb.get(bbval)
+ out.write(bbval)
+ }
+
out.writeInt(accumUpdates.size)
for ((key, value) <- accumUpdates) {
out.writeLong(key)
@@ -23,7 +37,16 @@ class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics:
}
override def readExternal(in: ObjectInput) {
- value = in.readObject().asInstanceOf[T]
+
+ //this doesn't work since SparkEnv.get == null
+ // in this context
+ val objectSer = SparkEnv.get.serializer.newInstance()
+
+ val blen = in.readInt()
+ val byteVal = new Array[Byte](blen)
+ in.readFully(byteVal)
+ value = objectSer.deserialize(ByteBuffer.wrap(byteVal))
+
val numUpdates = in.readInt
if (numUpdates == 0) {
accumUpdates = null
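
The change above replaces plain Java serialization of the task result's value with a length-prefixed byte frame produced by the serializer configured in SparkEnv. A minimal, self-contained sketch of that framing is shown below; the serialize/deserialize function parameters stand in for a SerializerInstance and are assumptions, not the actual Spark API.

// Sketch only: length-prefixed framing of a serialized value over DataOutput/DataInput,
// mirroring the writeExternal/readExternal change above. The serialize/deserialize
// parameters are stand-ins for a SerializerInstance-like object.
import java.io.{DataInput, DataOutput}
import java.nio.ByteBuffer

object FramedValue {
  def write[T](out: DataOutput, value: T, serialize: T => ByteBuffer): Unit = {
    val bb = serialize(value)
    out.writeInt(bb.remaining())          // length prefix tells the reader how many bytes follow
    val bytes = new Array[Byte](bb.remaining())
    bb.get(bytes)                         // copying works for array-backed and direct buffers alike
    out.write(bytes)
  }

  def read[T](in: DataInput, deserialize: ByteBuffer => T): T = {
    val len = in.readInt()
    val bytes = new Array[Byte](len)
    in.readFully(bytes)                   // readFully blocks until all len bytes are read
    deserialize(ByteBuffer.wrap(bytes))
  }
}

ObjectOutput and ObjectInput extend DataOutput and DataInput, so the same framing applies directly inside writeExternal/readExternal as in the diff.
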
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index 93d4318b29..0b800ec740 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -145,6 +145,7 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
// Set the Spark execution environment for the worker thread
SparkEnv.set(env)
val ser = SparkEnv.get.closureSerializer.newInstance()
+ val objectSer = SparkEnv.get.serializer.newInstance()
try {
Accumulators.clear()
Thread.currentThread().setContextClassLoader(classLoader)
@@ -165,9 +166,9 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
// executor does. This is useful to catch serialization errors early
// on in development (so when users move their local Spark programs
// to the cluster, they don't get surprised by serialization errors).
- val serResult = ser.serialize(result)
+ val serResult = objectSer.serialize(result)
deserializedTask.metrics.get.resultSize = serResult.limit()
- val resultToReturn = ser.deserialize[Any](serResult)
+ val resultToReturn = objectSer.deserialize[Any](serResult)
val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
ser.serialize(Accumulators.values))
logInfo("Finished " + taskId)
@@ -218,6 +219,7 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
val taskSetId = taskIdToTaskSetId(taskId)
val taskSetManager = activeTaskSets(taskSetId)
taskSetTaskIds(taskSetId) -= taskId
+ SparkEnv.set(env)
taskSetManager.statusUpdate(taskId, state, serializedData)
}
}
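
The extra SparkEnv.set(env) in statusUpdate addresses the problem flagged by the comment in TaskResult.readExternal: the environment appears to be tracked per thread, so a thread that deserializes a TaskResult sees SparkEnv.get == null unless it installs the environment first. A stripped-down sketch of that per-thread-holder pattern follows; Env and EnvHolder are hypothetical stand-ins for SparkEnv, not Spark code.

// Sketch only: a per-thread environment holder, illustrating why each thread that
// touches a TaskResult must call set() before get() returns anything.
final case class Env(serializerName: String)

object EnvHolder {
  private val local = new ThreadLocal[Env]
  def set(e: Env): Unit = local.set(e)
  def get: Env = local.get()              // null on any thread that never called set()
}

object ThreadLocalEnvDemo extends App {
  val env = Env("spark.KryoSerializer")
  val worker = new Thread(new Runnable {
    override def run(): Unit = {
      EnvHolder.set(env)                  // without this line, EnvHolder.get is null here
      assert(EnvHolder.get != null)
    }
  })
  worker.start()
  worker.join()
}
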
diff --git a/core/src/test/scala/spark/KryoSerializerSuite.scala b/core/src/test/scala/spark/KryoSerializerSuite.scala
index 327e2ff848..a2f91648b6 100644
--- a/core/src/test/scala/spark/KryoSerializerSuite.scala
+++ b/core/src/test/scala/spark/KryoSerializerSuite.scala
@@ -7,6 +7,7 @@ import org.scalatest.FunSuite
import com.esotericsoftware.kryo._
import SparkContext._
+import spark.test.{ClassWithoutNoArgConstructor, MyRegistrator}
class KryoSerializerSuite extends FunSuite {
test("basic types") {
@@ -109,6 +110,20 @@ class KryoSerializerSuite extends FunSuite {
System.clearProperty("spark.kryo.registrator")
}
+
+ test("kryo-collect") {
+ System.setProperty("spark.serializer", "spark.KryoSerializer")
+ System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName)
+
+ val sc = new SparkContext("local", "kryoTest")
+ val control = 1 :: 2 :: Nil
+ val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)).collect().map(_.x)
+ assert(control == result.toSeq)
+
+ System.clearProperty("spark.kryo.registrator")
+ System.clearProperty("spark.serializer")
+ }
+
}
package test {
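
The new kryo-collect test uses MyRegistrator and ClassWithoutNoArgConstructor from the suite's existing package test block, which this diff does not show. For orientation only, a registrator of that kind is usually just a class that tells Kryo which types to register; the sketch below assumes the era's spark.KryoRegistrator trait with a registerClasses(kryo) hook, and its class definitions are illustrative stand-ins rather than the suite's actual ones.

// Sketch only: what a registrator like spark.test.MyRegistrator typically looks like.
// The spark.KryoRegistrator trait and its registerClasses signature are assumed here.
import com.esotericsoftware.kryo.Kryo

class ClassWithoutNoArgConstructor(val x: Int)

class MyRegistrator extends spark.KryoRegistrator {
  override def registerClasses(kryo: Kryo): Unit = {
    kryo.register(classOf[ClassWithoutNoArgConstructor])
  }
}

The test then selects Kryo by setting spark.serializer and spark.kryo.registrator as system properties before building the local SparkContext, so the collect path exercises the new Kryo-backed TaskResult framing rather than Java serialization.
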