diff options
author | Matei Zaharia <matei.zaharia@gmail.com> | 2013-08-06 22:31:02 -0700 |
---|---|---|
committer | Matei Zaharia <matei.zaharia@gmail.com> | 2013-08-06 22:31:02 -0700 |
commit | 6b043a6f1118402df2e0b7fad2cea2dba48f9992 (patch) | |
tree | b2befd2da9f9a13194799bc1104989b591fe71ad /core | |
parent | 7c4b7a53b1b588c1d0d3e00e99d4d7c53dc1da3d (diff) | |
parent | d29ee3689be8b8353d4f2440ac4e0237e98be056 (diff) | |
download | spark-6b043a6f1118402df2e0b7fad2cea2dba48f9992.tar.gz spark-6b043a6f1118402df2e0b7fad2cea2dba48f9992.tar.bz2 spark-6b043a6f1118402df2e0b7fad2cea2dba48f9992.zip |
Merge pull request #724 from dlyubimov/SPARK-826
SPARK-826: fold(), reduce(), collect() always attempt to use java serialization
Diffstat (limited to 'core')
7 files changed, 162 insertions, 22 deletions
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index ef598ae41b..673f9a810d 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -33,8 +33,9 @@ import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.hadoop.fs.{Path, FileSystem, FileUtil} -import spark.serializer.SerializerInstance +import spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} import spark.deploy.SparkHadoopUtil +import java.nio.ByteBuffer /** @@ -68,6 +69,47 @@ private object Utils extends Logging { return ois.readObject.asInstanceOf[T] } + /** Serialize via nested stream using specific serializer */ + def serializeViaNestedStream(os: OutputStream, ser: SerializerInstance)(f: SerializationStream => Unit) = { + val osWrapper = ser.serializeStream(new OutputStream { + def write(b: Int) = os.write(b) + + override def write(b: Array[Byte], off: Int, len: Int) = os.write(b, off, len) + }) + try { + f(osWrapper) + } finally { + osWrapper.close() + } + } + + /** Deserialize via nested stream using specific serializer */ + def deserializeViaNestedStream(is: InputStream, ser: SerializerInstance)(f: DeserializationStream => Unit) = { + val isWrapper = ser.deserializeStream(new InputStream { + def read(): Int = is.read() + + override def read(b: Array[Byte], off: Int, len: Int): Int = is.read(b, off, len) + }) + try { + f(isWrapper) + } finally { + isWrapper.close() + } + } + + /** + * Primitive often used when writing {@link java.nio.ByteBuffer} to {@link java.io.DataOutput}. + */ + def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput) = { + if (bb.hasArray) { + out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) + } else { + val bbval = new Array[Byte](bb.remaining()) + bb.get(bbval) + out.write(bbval) + } + } + def isAlpha(c: Char): Boolean = { (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') } diff --git a/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala index 16ba0c26f8..33079cd539 100644 --- a/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala @@ -20,13 +20,15 @@ package spark.rdd import scala.collection.immutable.NumericRange import scala.collection.mutable.ArrayBuffer import scala.collection.Map -import spark.{RDD, TaskContext, SparkContext, Partition} +import spark._ +import java.io._ +import scala.Serializable private[spark] class ParallelCollectionPartition[T: ClassManifest]( - val rddId: Long, - val slice: Int, - values: Seq[T]) - extends Partition with Serializable { + var rddId: Long, + var slice: Int, + var values: Seq[T]) + extends Partition with Serializable { def iterator: Iterator[T] = values.iterator @@ -37,15 +39,49 @@ private[spark] class ParallelCollectionPartition[T: ClassManifest]( case _ => false } - override val index: Int = slice + override def index: Int = slice + + @throws(classOf[IOException]) + private def writeObject(out: ObjectOutputStream): Unit = { + + val sfactory = SparkEnv.get.serializer + + // Treat java serializer with default action rather than going thru serialization, to avoid a + // separate serialization header. + + sfactory match { + case js: JavaSerializer => out.defaultWriteObject() + case _ => + out.writeLong(rddId) + out.writeInt(slice) + + val ser = sfactory.newInstance() + Utils.serializeViaNestedStream(out, ser)(_.writeObject(values)) + } + } + + @throws(classOf[IOException]) + private def readObject(in: ObjectInputStream): Unit = { + + val sfactory = SparkEnv.get.serializer + sfactory match { + case js: JavaSerializer => in.defaultReadObject() + case _ => + rddId = in.readLong() + slice = in.readInt() + + val ser = sfactory.newInstance() + Utils.deserializeViaNestedStream(in, ser)(ds => values = ds.readObject()) + } + } } private[spark] class ParallelCollectionRDD[T: ClassManifest]( @transient sc: SparkContext, @transient data: Seq[T], numSlices: Int, - locationPrefs: Map[Int,Seq[String]]) - extends RDD[T](sc, Nil) { + locationPrefs: Map[Int, Seq[String]]) + extends RDD[T](sc, Nil) { // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split // instead. @@ -82,16 +118,17 @@ private object ParallelCollectionRDD { 1 } slice(new Range( - r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices) + r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices) } case r: Range => { (0 until numSlices).map(i => { val start = ((i * r.length.toLong) / numSlices).toInt - val end = (((i+1) * r.length.toLong) / numSlices).toInt + val end = (((i + 1) * r.length.toLong) / numSlices).toInt new Range(r.start + start * r.step, r.start + end * r.step, r.step) }).asInstanceOf[Seq[Seq[T]]] } - case nr: NumericRange[_] => { // For ranges of Long, Double, BigInteger, etc + case nr: NumericRange[_] => { + // For ranges of Long, Double, BigInteger, etc val slices = new ArrayBuffer[Seq[T]](numSlices) val sliceSize = (nr.size + numSlices - 1) / numSlices // Round up to catch everything var r = nr @@ -102,10 +139,10 @@ private object ParallelCollectionRDD { slices } case _ => { - val array = seq.toArray // To prevent O(n^2) operations for List etc + val array = seq.toArray // To prevent O(n^2) operations for List etc (0 until numSlices).map(i => { val start = ((i * array.length.toLong) / numSlices).toInt - val end = (((i+1) * array.length.toLong) / numSlices).toInt + val end = (((i + 1) * array.length.toLong) / numSlices).toInt array.slice(start, end).toSeq }) } diff --git a/core/src/main/scala/spark/scheduler/TaskResult.scala b/core/src/main/scala/spark/scheduler/TaskResult.scala index dc0621ea7b..89793e0e82 100644 --- a/core/src/main/scala/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/spark/scheduler/TaskResult.scala @@ -21,6 +21,8 @@ import java.io._ import scala.collection.mutable.Map import spark.executor.TaskMetrics +import spark.{Utils, SparkEnv} +import java.nio.ByteBuffer // Task result. Also contains updates to accumulator variables. // TODO: Use of distributed cache to return result is a hack to get around @@ -30,7 +32,13 @@ class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics: def this() = this(null.asInstanceOf[T], null, null) override def writeExternal(out: ObjectOutput) { - out.writeObject(value) + + val objectSer = SparkEnv.get.serializer.newInstance() + val bb = objectSer.serialize(value) + + out.writeInt(bb.remaining()) + Utils.writeByteBuffer(bb, out) + out.writeInt(accumUpdates.size) for ((key, value) <- accumUpdates) { out.writeLong(key) @@ -40,7 +48,14 @@ class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics: } override def readExternal(in: ObjectInput) { - value = in.readObject().asInstanceOf[T] + + val objectSer = SparkEnv.get.serializer.newInstance() + + val blen = in.readInt() + val byteVal = new Array[Byte](blen) + in.readFully(byteVal) + value = objectSer.deserialize(ByteBuffer.wrap(byteVal)) + val numUpdates = in.readInt if (numUpdates == 0) { accumUpdates = null diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala index d2110bd098..7f855cd345 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala @@ -92,7 +92,8 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet: val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble // Serializer for closures and tasks. - val ser = SparkEnv.get.closureSerializer.newInstance() + val env = SparkEnv.get + val ser = env.closureSerializer.newInstance() val tasks = taskSet.tasks val numTasks = tasks.length @@ -534,6 +535,7 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet: } override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { + SparkEnv.set(env) state match { case TaskState.FINISHED => taskFinished(tid, state, serializedData) diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index bb0c836e86..f274b1a767 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -169,7 +169,8 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: // Set the Spark execution environment for the worker thread SparkEnv.set(env) val ser = SparkEnv.get.closureSerializer.newInstance() - var attemptedTask: Option[Task[_]] = None + val objectSer = SparkEnv.get.serializer.newInstance() + var attemptedTask: Option[Task[_]] = None val start = System.currentTimeMillis() var taskStart: Long = 0 try { @@ -193,9 +194,9 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: // executor does. This is useful to catch serialization errors early // on in development (so when users move their local Spark programs // to the cluster, they don't get surprised by serialization errors). - val serResult = ser.serialize(result) + val serResult = objectSer.serialize(result) deserializedTask.metrics.get.resultSize = serResult.limit() - val resultToReturn = ser.deserialize[Any](serResult) + val resultToReturn = objectSer.deserialize[Any](serResult) val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]]( ser.serialize(Accumulators.values)) val serviceTime = System.currentTimeMillis() - taskStart diff --git a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala index 4ab15532cf..c38eeb9e11 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala @@ -42,7 +42,8 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas val taskInfos = new HashMap[Long, TaskInfo] val numTasks = taskSet.tasks.size var numFinished = 0 - val ser = SparkEnv.get.closureSerializer.newInstance() + val env = SparkEnv.get + val ser = env.closureSerializer.newInstance() val copiesRunning = new Array[Int](numTasks) val finished = new Array[Boolean](numTasks) val numFailures = new Array[Int](numTasks) @@ -143,6 +144,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas } override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { + SparkEnv.set(env) state match { case TaskState.FINISHED => taskEnded(tid, state, serializedData) diff --git a/core/src/test/scala/spark/KryoSerializerSuite.scala b/core/src/test/scala/spark/KryoSerializerSuite.scala index 30d2d5282b..01390027c8 100644 --- a/core/src/test/scala/spark/KryoSerializerSuite.scala +++ b/core/src/test/scala/spark/KryoSerializerSuite.scala @@ -22,7 +22,9 @@ import scala.collection.mutable import org.scalatest.FunSuite import com.esotericsoftware.kryo._ -class KryoSerializerSuite extends FunSuite { +import KryoTest._ + +class KryoSerializerSuite extends FunSuite with SharedSparkContext { test("basic types") { val ser = (new KryoSerializer).newInstance() def check[T](t: T) { @@ -124,6 +126,45 @@ class KryoSerializerSuite extends FunSuite { System.clearProperty("spark.kryo.registrator") } + + test("kryo with collect") { + val control = 1 :: 2 :: Nil + val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)).collect().map(_.x) + assert(control === result.toSeq) + } + + test("kryo with parallelize") { + val control = 1 :: 2 :: Nil + val result = sc.parallelize(control.map(new ClassWithoutNoArgConstructor(_))).map(_.x).collect() + assert (control === result.toSeq) + } + + test("kryo with reduce") { + val control = 1 :: 2 :: Nil + val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)) + .reduce((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x + assert(control.sum === result) + } + + // TODO: this still doesn't work + ignore("kryo with fold") { + val control = 1 :: 2 :: Nil + val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)) + .fold(new ClassWithoutNoArgConstructor(10))((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x + assert(10 + control.sum === result) + } + + override def beforeAll() { + System.setProperty("spark.serializer", "spark.KryoSerializer") + System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName) + super.beforeAll() + } + + override def afterAll() { + super.afterAll() + System.clearProperty("spark.kryo.registrator") + System.clearProperty("spark.serializer") + } } object KryoTest { |