5 files changed, 96 insertions(+), 22 deletions(-)
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index ca7cdd622a..498b2718c4 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -641,7 +641,7 @@ abstract class RDD[T: ClassManifest](
    */
   def fold(zeroValue: T)(op: (T, T) => T): T = {
     // Clone the zero value since we will also be serializing it as part of tasks
-    var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
+    var jobResult = Utils.clone(zeroValue, sc.env.serializer.newInstance())
     val cleanOp = sc.clean(op)
     val foldPartition = (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp)
     val mergeResult = (index: Int, taskResult: T) => jobResult = op(jobResult, taskResult)
diff --git a/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala
index 4842e2fb35..672c623537 100644
--- a/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala
+++ b/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala
@@ -23,7 +23,7 @@ import scala.collection.Map
 import spark._
 import java.io._
 import scala.Serializable
-import java.nio.ByteBuffer
+import util.{NestedInputStream, NestedOutputStream}
 
 private[spark] class ParallelCollectionPartition[T: ClassManifest](
     var rddId: Long,
@@ -57,12 +57,12 @@ private[spark] class ParallelCollectionPartition[T: ClassManifest](
         out.writeInt(slice)
 
         val ser = sfactory.newInstance()
-        out.writeInt(values.size)
-        values.foreach(v => {
-          val bb = ser.serialize(v)
-          out.writeInt(bb.remaining())
-          Utils.writeByteBuffer(bb, out)
-        })
+        val ssout = ser.serializeStream(new NestedOutputStream(out))
+        try {
+          ssout.writeObject(values)
+        } finally {
+          ssout.close()
+        }
     }
   }
 
@@ -77,20 +77,12 @@ private[spark] class ParallelCollectionPartition[T: ClassManifest](
         slice = in.readInt()
 
         val ser = sfactory.newInstance()
-        val s = in.readInt()
-        var bb: ByteBuffer = null
-        values = (0 until s).map(i => {
-          val len = in.readInt()
-          if (bb == null || bb.capacity < len) {
-            bb = ByteBuffer.allocate(len)
-          } else {
-            bb.clear
-          }
-
-          in.readFully(bb.array(), 0, len);
-          bb.limit(len)
-          ser.deserialize(bb): T
-        }).toSeq
+        val ssin = ser.deserializeStream(new NestedInputStream(in))
+        try {
+          values = ssin.readObject()
+        } finally {
+          ssin.close()
+        }
     }
   }
 }
diff --git a/core/src/main/scala/spark/util/NestedInputStream.scala b/core/src/main/scala/spark/util/NestedInputStream.scala
new file mode 100644
index 0000000000..5612f57dcc
--- /dev/null
+++ b/core/src/main/scala/spark/util/NestedInputStream.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.util
+
+import java.io.InputStream
+
+/**
+ * Nested input stream: reads a nested stream out of an outer stream without closing it.
+ * <p>
+ * This stream does not assume a nested envelope of its own and does not look for an EOF marker;
+ * it is still the responsibility of the outer code to know where the nested data ends. In fact,
+ * the only purpose of this wrapper is to prevent a nested stream (such as JavaSerializer's
+ * stream) from closing the underlying input when it is closed.
+ */
+class NestedInputStream(val in: InputStream) extends InputStream {
+  override def read(): Int = in.read()
+  override def read(b: Array[Byte], off: Int, len: Int): Int = in.read(b, off, len)
+}
diff --git a/core/src/main/scala/spark/util/NestedOutputStream.scala b/core/src/main/scala/spark/util/NestedOutputStream.scala
new file mode 100644
index 0000000000..f25dcdce8f
--- /dev/null
+++ b/core/src/main/scala/spark/util/NestedOutputStream.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.util
+
+import java.io.OutputStream
+
+/**
+ * Nested output stream: writes a nested stream into an outer stream without closing it.
+ * <p>
+ * This stream does not provide a nested envelope of its own and does not write an EOF marker;
+ * it is still the responsibility of the outer code to write a stop token. In fact, the only
+ * purpose of this wrapper is to prevent a nested stream (such as JavaSerializer's stream) from
+ * closing the underlying output when it is closed.
+ */
+class NestedOutputStream(val out: OutputStream) extends OutputStream {
+  override def write(b: Int) = out.write(b)
+  override def write(b: Array[Byte], off: Int, len: Int) = out.write(b, off, len)
+
+}
diff --git a/core/src/test/scala/spark/KryoSerializerSuite.scala b/core/src/test/scala/spark/KryoSerializerSuite.scala
index 158bf63132..9a667bd2fb 100644
--- a/core/src/test/scala/spark/KryoSerializerSuite.scala
+++ b/core/src/test/scala/spark/KryoSerializerSuite.scala
@@ -140,6 +140,21 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext {
     assert (control === result.toSeq)
   }
 
+  test("kryo with reduce") {
+    val control = 1 :: 2 :: Nil
+    val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_))
+      .reduce((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x
+    assert(control.sum === result)
+  }
+
+  // TODO: this still doesn't work: fold() captures zeroValue in its partition closure, which is serialized with the Java closure serializer, so a Kryo-only class still fails there
+//  test("kryo with fold") {
+//    val control = 1 :: 2 :: Nil
+//    val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_))
+//      .fold(new ClassWithoutNoArgConstructor(10))((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x
+//    assert(10 + control.sum === result)
+//  }
+
   override def beforeAll() {
     System.setProperty("spark.serializer", "spark.KryoSerializer")
     System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName)
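A round-trip sketch of the pattern these wrappers enable (hypothetical demo code, not part of the patch): because InputStream.close() and OutputStream.close() are no-ops in the java.io base classes, simply not overriding close() is what makes the wrappers "non-closing". A nested ObjectOutputStream can therefore be closed, which flushes it, without tearing down the outer stream around it. Here DataOutputStream/DataInputStream stand in for the Externalizable ObjectOutput/ObjectInput used in ParallelCollectionPartition:

    import java.io._
    import spark.util.{NestedInputStream, NestedOutputStream}

    object NestedStreamDemo extends App {
      val bytes = new ByteArrayOutputStream()
      val outer = new DataOutputStream(bytes)

      outer.writeInt(7)                 // outer framing written before the nested envelope

      // Nested Java-serialization envelope; NestedOutputStream swallows close().
      val nested = new ObjectOutputStream(new NestedOutputStream(outer))
      nested.writeObject("payload")
      nested.close()                    // flushes `nested`, leaves `outer` open

      outer.writeInt(8)                 // the outer stream is still usable
      outer.flush()

      val in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray))
      assert(in.readInt() == 7)
      val nestedIn = new ObjectInputStream(new NestedInputStream(in))
      assert(nestedIn.readObject() == "payload")
      nestedIn.close()                  // no-op on the underlying stream
      assert(in.readInt() == 8)         // outer framing resumes where the envelope ended
    }

This is also why writeExternal/readExternal above close the serializer stream in a finally block: close() flushes the nested serializer's buffers through the wrapper, while the Externalizable stream itself stays open for whatever the outer protocol reads or writes next.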