From 8e5cd041bbf7f802794b8f6e960f702fb59e5863 Mon Sep 17 00:00:00 2001
From: Dmitriy Lyubimov
Date: Mon, 29 Jul 2013 18:25:33 -0700
Subject: initial externalization of ParallelCollectionRDD's split

---
 .../scala/spark/rdd/ParallelCollectionRDD.scala | 70 +++++++++++++++++-----
 1 file changed, 55 insertions(+), 15 deletions(-)

diff --git a/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala
index 16ba0c26f8..104257ac07 100644
--- a/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala
+++ b/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala
@@ -20,13 +20,15 @@ package spark.rdd
 import scala.collection.immutable.NumericRange
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.Map
-import spark.{RDD, TaskContext, SparkContext, Partition}
+import spark._
+import java.io.{ObjectInput, ObjectOutput, Externalizable}
+import java.nio.ByteBuffer
 
 private[spark] class ParallelCollectionPartition[T: ClassManifest](
-    val rddId: Long,
-    val slice: Int,
-    values: Seq[T])
-  extends Partition with Serializable {
+  var rddId: Long,
+  var slice: Int,
+  var values: Seq[T])
+  extends Partition with Externalizable {
 
   def iterator: Iterator[T] = values.iterator
 
@@ -37,14 +39,51 @@ private[spark] class ParallelCollectionPartition[T: ClassManifest](
     case _ => false
   }
 
-  override val index: Int = slice
+  override def index: Int = slice
+
+  override def writeExternal(out: ObjectOutput) {
+    out.writeLong(rddId)
+    out.writeInt(slice)
+    out.writeInt(values.size)
+    val ser = SparkEnv.get.serializer.newInstance()
+    values.foreach(x => {
+      val bb = ser.serialize(x)
+      out.writeInt(bb.remaining())
+      if (bb.hasArray) {
+        out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
+      } else {
+        val b = new Array[Byte](bb.remaining())
+        bb.get(b)
+        out.write(b)
+      }
+    })
+  }
+
+  override def readExternal(in: ObjectInput) {
+    rddId = in.readLong()
+    slice = in.readInt()
+    val s = in.readInt()
+    val ser = SparkEnv.get.serializer.newInstance()
+    var bb = ByteBuffer.allocate(1024)
+    values = (0 until s).map({
+      val s = in.readInt()
+      if (bb.capacity() < s) {
+        bb = ByteBuffer.allocate(s)
+      } else {
+        bb.clear()
+      }
+      in.readFully(bb.array())
+      bb.limit(s)
+      ser.deserialize(bb)
+    }).toSeq
+  }
 }
 
 private[spark] class ParallelCollectionRDD[T: ClassManifest](
-    @transient sc: SparkContext,
-    @transient data: Seq[T],
-    numSlices: Int,
-    locationPrefs: Map[Int,Seq[String]])
+  @transient sc: SparkContext,
+  @transient data: Seq[T],
+  numSlices: Int,
+  locationPrefs: Map[Int, Seq[String]])
   extends RDD[T](sc, Nil) {
   // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets
   // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split
@@ -82,16 +121,17 @@ private object ParallelCollectionRDD {
           1
         }
         slice(new Range(
-            r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices)
+          r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices)
       }
       case r: Range => {
         (0 until numSlices).map(i => {
           val start = ((i * r.length.toLong) / numSlices).toInt
-          val end = (((i+1) * r.length.toLong) / numSlices).toInt
+          val end = (((i + 1) * r.length.toLong) / numSlices).toInt
           new Range(r.start + start * r.step, r.start + end * r.step, r.step)
         }).asInstanceOf[Seq[Seq[T]]]
       }
-      case nr: NumericRange[_] => { // For ranges of Long, Double, BigInteger, etc
+      case nr: NumericRange[_] => {
+        // For ranges of Long, Double, BigInteger, etc
         val slices = new ArrayBuffer[Seq[T]](numSlices)
         val sliceSize = (nr.size + numSlices - 1) / numSlices // Round up to catch everything
         var r = nr
@@ -102,10 +142,10 @@ private object ParallelCollectionRDD {
         slices
       }
      case _ => {
-        val array = seq.toArray  // To prevent O(n^2) operations for List etc
+        val array = seq.toArray // To prevent O(n^2) operations for List etc
         (0 until numSlices).map(i => {
           val start = ((i * array.length.toLong) / numSlices).toInt
-          val end = (((i+1) * array.length.toLong) / numSlices).toInt
+          val end = (((i + 1) * array.length.toLong) / numSlices).toInt
           array.slice(start, end).toSeq
         })
       }
--
cgit v1.2.3
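The substantive change in the patch is the Externalizable pair on ParallelCollectionPartition: writeExternal records the rdd id, the slice index, and an element count, then length-prefixes each element's serialized bytes; readExternal reads the count back and sizes a reusable buffer from each length prefix before handing it to the deserializer. Below is a minimal standalone sketch of that length-prefixed encoding, not part of the patch: plain Java serialization stands in for SparkEnv.get.serializer, and the names are illustrative only.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream,
  ObjectInputStream, ObjectOutputStream}

// Standalone sketch of the length-prefixed element encoding used by writeExternal/readExternal.
// Java serialization is assumed here as a stand-in for SparkEnv.get.serializer.
object LengthPrefixedSeqCodec {

  // Serialize one element to a byte array (stand-in for ser.serialize(x): ByteBuffer).
  private def toBytes(x: AnyRef): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(x)
    oos.close()
    bos.toByteArray
  }

  // Write: element count, then for each element its byte length followed by the bytes.
  def write(out: DataOutputStream, values: Seq[AnyRef]): Unit = {
    out.writeInt(values.size)
    values.foreach { x =>
      val bytes = toBytes(x)
      out.writeInt(bytes.length)
      out.write(bytes)
    }
  }

  // Read: element count, then length-prefixed payloads, deserializing each one.
  def read(in: DataInputStream): Seq[AnyRef] = {
    val n = in.readInt()
    (0 until n).map { _ =>
      val len = in.readInt()
      val buf = new Array[Byte](len)
      in.readFully(buf)
      new ObjectInputStream(new ByteArrayInputStream(buf)).readObject()
    }
  }

  def main(args: Array[String]): Unit = {
    val bos = new ByteArrayOutputStream()
    write(new DataOutputStream(bos), Seq("a", "b", "c"))
    val back = read(new DataInputStream(new ByteArrayInputStream(bos.toByteArray)))
    println(back) // Vector(a, b, c)
  }
}

The patch itself goes one step further and reuses a single ByteBuffer across elements, reallocating only when a length prefix exceeds the current capacity, which matters when a partition carries many small elements.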
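The later hunks in ParallelCollectionRDD.slice are mostly whitespace cleanups around the existing slicing arithmetic: slice i of an n-element collection covers positions (i * n) / numSlices until ((i + 1) * n) / numSlices, computed in Long to avoid overflow, so consecutive slices meet exactly with no gaps or overlap. A small sketch of just the generic case follows, assuming modern ClassTag in place of the ClassManifest used in the patch; the Range and NumericRange special cases are omitted.

import scala.reflect.ClassTag

// Sketch of the positional slice arithmetic from ParallelCollectionRDD.slice (generic Seq case only).
object SlicePositions {

  def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
    require(numSlices >= 1, "Positive number of slices required")
    val array = seq.toArray // avoid O(n^2) repeated slicing on List, as the patch's comment notes
    (0 until numSlices).map { i =>
      // Long arithmetic avoids overflow; integer division makes consecutive
      // slices share boundaries exactly, so nothing is dropped or duplicated.
      val start = ((i * array.length.toLong) / numSlices).toInt
      val end = (((i + 1) * array.length.toLong) / numSlices).toInt
      array.slice(start, end).toSeq
    }
  }

  def main(args: Array[String]): Unit = {
    println(slice(1 to 10, 3).map(_.mkString(","))) // Vector(1,2,3, 4,5,6, 7,8,9,10)
  }
}

Converting to an Array first is the same design choice the patch keeps: repeatedly slicing a List would cost O(n^2), while positional slicing of an Array is linear overall.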