author     Dmitriy Lyubimov <dlyubimov@apache.org>  2013-07-29 18:25:33 -0700
committer  Dmitriy Lyubimov <dlyubimov@apache.org>  2013-07-29 19:02:53 -0700
commit     8e5cd041bbf7f802794b8f6e960f702fb59e5863 (patch)
tree       d29935e060b28b178dded0d0ef55e8aaa8899448 /core
parent     b241fcfb35aeaedbcbc35df3f12ecc300c4302ab (diff)
download   spark-8e5cd041bbf7f802794b8f6e960f702fb59e5863.tar.gz
           spark-8e5cd041bbf7f802794b8f6e960f702fb59e5863.tar.bz2
           spark-8e5cd041bbf7f802794b8f6e960f702fb59e5863.zip
initial externalization of ParallelCollectionRDD's split
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala  70
1 file changed, 55 insertions, 15 deletions
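The change below relies on the java.io.Externalizable contract: when an instance passes through standard Java object streams, writeExternal is invoked on the way out, and on the way in the runtime first creates the object through a public no-arg constructor and then calls readExternal to repopulate its fields. A minimal, self-contained sketch of that contract, separate from the patch itself (the Point class and ExternalizableRoundTrip object are illustrative names invented here):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, Externalizable,
  ObjectInput, ObjectInputStream, ObjectOutput, ObjectOutputStream}

// Toy Externalizable class: a public no-arg constructor is required so that
// ObjectInputStream can instantiate it before invoking readExternal.
class Point(var x: Int, var y: Int) extends Externalizable {
  def this() = this(0, 0)
  override def writeExternal(out: ObjectOutput) { out.writeInt(x); out.writeInt(y) }
  override def readExternal(in: ObjectInput) { x = in.readInt(); y = in.readInt() }
  override def toString = "Point(" + x + ", " + y + ")"
}

object ExternalizableRoundTrip {
  def main(args: Array[String]) {
    val buf = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(buf)
    oos.writeObject(new Point(3, 4))                 // triggers writeExternal
    oos.close()
    val ois = new ObjectInputStream(new ByteArrayInputStream(buf.toByteArray))
    val p = ois.readObject().asInstanceOf[Point]     // no-arg ctor, then readExternal
    println(p)                                       // prints Point(3, 4)
  }
}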
diff --git a/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala
index 16ba0c26f8..104257ac07 100644
--- a/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala
+++ b/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala
@@ -20,13 +20,15 @@ package spark.rdd
import scala.collection.immutable.NumericRange
import scala.collection.mutable.ArrayBuffer
import scala.collection.Map
-import spark.{RDD, TaskContext, SparkContext, Partition}
+import spark._
+import java.io.{ObjectInput, ObjectOutput, Externalizable}
+import java.nio.ByteBuffer
private[spark] class ParallelCollectionPartition[T: ClassManifest](
- val rddId: Long,
- val slice: Int,
- values: Seq[T])
- extends Partition with Serializable {
+ var rddId: Long,
+ var slice: Int,
+ var values: Seq[T])
+ extends Partition with Externalizable {
def iterator: Iterator[T] = values.iterator
@@ -37,14 +39,51 @@ private[spark] class ParallelCollectionPartition[T: ClassManifest](
case _ => false
}
- override val index: Int = slice
+ override def index: Int = slice
+
+  override def writeExternal(out: ObjectOutput) {
+    out.writeLong(rddId)
+    out.writeInt(slice)
+    out.writeInt(values.size)
+    val ser = SparkEnv.get.serializer.newInstance()
+    // write each element as a length-prefixed record produced by the Spark serializer
+    values.foreach(x => {
+      val bb = ser.serialize(x)
+      out.writeInt(bb.remaining())
+      if (bb.hasArray) {
+        out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
+      } else {
+        val b = new Array[Byte](bb.remaining())
+        bb.get(b)
+        out.write(b)
+      }
+    })
+  }
+
+  override def readExternal(in: ObjectInput) {
+    rddId = in.readLong()
+    slice = in.readInt()
+    val n = in.readInt()
+    val ser = SparkEnv.get.serializer.newInstance()
+    var bb = ByteBuffer.allocate(1024)
+    // read back each length-prefixed record, growing the scratch buffer as needed
+    values = (0 until n).map(_ => {
+      val size = in.readInt()
+      if (bb.capacity() < size) {
+        bb = ByteBuffer.allocate(size)
+      } else {
+        bb.clear()
+      }
+      in.readFully(bb.array(), 0, size)
+      bb.limit(size)
+      ser.deserialize(bb)
+    }).toSeq
+  }
}
private[spark] class ParallelCollectionRDD[T: ClassManifest](
- @transient sc: SparkContext,
- @transient data: Seq[T],
- numSlices: Int,
- locationPrefs: Map[Int,Seq[String]])
+ @transient sc: SparkContext,
+ @transient data: Seq[T],
+ numSlices: Int,
+ locationPrefs: Map[Int, Seq[String]])
extends RDD[T](sc, Nil) {
// TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets
// cached. It might be worthwhile to write the data to a file in the DFS and read it in the split
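The writeExternal/readExternal pair added above frames the partition as a small header (rddId, slice, element count) followed by one length-prefixed record per element, each produced by the serializer obtained from SparkEnv. The sketch below restates that framing in isolation; it is only an illustration, with plain Java serialization standing in for SparkEnv.get.serializer, and the names FramingSketch, writeElements, and readElements are invented here:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInput, DataOutput,
  ObjectInputStream, ObjectOutputStream}

object FramingSketch {
  // Write the element count, then one (length, payload) record per element.
  def writeElements(out: DataOutput, values: Seq[AnyRef]) {
    out.writeInt(values.size)
    values.foreach(v => {
      val buf = new ByteArrayOutputStream()
      val oos = new ObjectOutputStream(buf)
      oos.writeObject(v)
      oos.close()
      val bytes = buf.toByteArray
      out.writeInt(bytes.length)   // length prefix
      out.write(bytes)             // payload
    })
  }

  // Read the count, then read exactly `length` bytes per record and decode it.
  def readElements(in: DataInput): Seq[AnyRef] = {
    val n = in.readInt()
    (0 until n).map(_ => {
      val length = in.readInt()
      val bytes = new Array[Byte](length)
      in.readFully(bytes)
      new ObjectInputStream(new ByteArrayInputStream(bytes)).readObject()
    })
  }
}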
@@ -82,16 +121,17 @@ private object ParallelCollectionRDD {
1
}
slice(new Range(
- r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices)
+ r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices)
}
case r: Range => {
(0 until numSlices).map(i => {
val start = ((i * r.length.toLong) / numSlices).toInt
- val end = (((i+1) * r.length.toLong) / numSlices).toInt
+ val end = (((i + 1) * r.length.toLong) / numSlices).toInt
new Range(r.start + start * r.step, r.start + end * r.step, r.step)
}).asInstanceOf[Seq[Seq[T]]]
}
- case nr: NumericRange[_] => { // For ranges of Long, Double, BigInteger, etc
+ case nr: NumericRange[_] => {
+ // For ranges of Long, Double, BigInteger, etc
val slices = new ArrayBuffer[Seq[T]](numSlices)
val sliceSize = (nr.size + numSlices - 1) / numSlices // Round up to catch everything
var r = nr
@@ -102,10 +142,10 @@ private object ParallelCollectionRDD {
slices
}
case _ => {
- val array = seq.toArray // To prevent O(n^2) operations for List etc
+ val array = seq.toArray // To prevent O(n^2) operations for List etc
(0 until numSlices).map(i => {
val start = ((i * array.length.toLong) / numSlices).toInt
- val end = (((i+1) * array.length.toLong) / numSlices).toInt
+ val end = (((i + 1) * array.length.toLong) / numSlices).toInt
array.slice(start, end).toSeq
})
}
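For reference, the slice arithmetic touched by the last hunk splits a length-n collection into numSlices contiguous chunks with start = i*n/numSlices and end = (i+1)*n/numSlices, computed in Long to avoid overflow. A small standalone check of that formula (SliceBoundsSketch and sliceBounds are illustrative names, not part of the patch):

object SliceBoundsSketch {
  // Chunk i covers indices [i*n/numSlices, (i+1)*n/numSlices),
  // the same arithmetic used in ParallelCollectionRDD.slice above.
  def sliceBounds(n: Int, numSlices: Int): Seq[(Int, Int)] =
    (0 until numSlices).map(i => {
      val start = ((i * n.toLong) / numSlices).toInt
      val end = (((i + 1) * n.toLong) / numSlices).toInt
      (start, end)
    })

  def main(args: Array[String]) {
    // 10 elements over 3 slices -> Vector((0,3), (3,6), (6,10)): sizes 3, 3, 4.
    println(sliceBounds(10, 3))
  }
}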