diff options
Diffstat (limited to 'core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala')
-rw-r--r-- | core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala | 153 |
1 file changed, 153 insertions, 0 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala new file mode 100644 index 0000000000..6dbd4309aa --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rdd + +import scala.collection.immutable.NumericRange +import scala.collection.mutable.ArrayBuffer +import scala.collection.Map +import org.apache.spark._ +import java.io._ +import scala.Serializable +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.util.Utils + +private[spark] class ParallelCollectionPartition[T: ClassManifest]( + var rddId: Long, + var slice: Int, + var values: Seq[T]) + extends Partition with Serializable { + + def iterator: Iterator[T] = values.iterator + + override def hashCode(): Int = (41 * (41 + rddId) + slice).toInt + + override def equals(other: Any): Boolean = other match { + case that: ParallelCollectionPartition[_] => (this.rddId == that.rddId && this.slice == that.slice) + case _ => false + } + + override def index: Int = slice + + @throws(classOf[IOException]) + private def writeObject(out: ObjectOutputStream): Unit = { + + val sfactory = SparkEnv.get.serializer + + // Treat java serializer with default action rather than going thru serialization, to avoid a + // separate serialization header. 
+ + sfactory match { + case js: JavaSerializer => out.defaultWriteObject() + case _ => + out.writeLong(rddId) + out.writeInt(slice) + + val ser = sfactory.newInstance() + Utils.serializeViaNestedStream(out, ser)(_.writeObject(values)) + } + } + + @throws(classOf[IOException]) + private def readObject(in: ObjectInputStream): Unit = { + + val sfactory = SparkEnv.get.serializer + sfactory match { + case js: JavaSerializer => in.defaultReadObject() + case _ => + rddId = in.readLong() + slice = in.readInt() + + val ser = sfactory.newInstance() + Utils.deserializeViaNestedStream(in, ser)(ds => values = ds.readObject()) + } + } +} + +private[spark] class ParallelCollectionRDD[T: ClassManifest]( + @transient sc: SparkContext, + @transient data: Seq[T], + numSlices: Int, + locationPrefs: Map[Int, Seq[String]]) + extends RDD[T](sc, Nil) { + // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets + // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split + // instead. + // UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal. + + override def getPartitions: Array[Partition] = { + val slices = ParallelCollectionRDD.slice(data, numSlices).toArray + slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray + } + + override def compute(s: Partition, context: TaskContext) = + s.asInstanceOf[ParallelCollectionPartition[T]].iterator + + override def getPreferredLocations(s: Partition): Seq[String] = { + locationPrefs.getOrElse(s.index, Nil) + } +} + +private object ParallelCollectionRDD { + /** + * Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range + * collections specially, encoding the slices as other Ranges to minimize memory cost. This makes + * it efficient to run Spark over RDDs representing large sets of numbers. 
+ */ + def slice[T: ClassManifest](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = { + if (numSlices < 1) { + throw new IllegalArgumentException("Positive number of slices required") + } + seq match { + case r: Range.Inclusive => { + val sign = if (r.step < 0) { + -1 + } else { + 1 + } + slice(new Range( + r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices) + } + case r: Range => { + (0 until numSlices).map(i => { + val start = ((i * r.length.toLong) / numSlices).toInt + val end = (((i + 1) * r.length.toLong) / numSlices).toInt + new Range(r.start + start * r.step, r.start + end * r.step, r.step) + }).asInstanceOf[Seq[Seq[T]]] + } + case nr: NumericRange[_] => { + // For ranges of Long, Double, BigInteger, etc + val slices = new ArrayBuffer[Seq[T]](numSlices) + val sliceSize = (nr.size + numSlices - 1) / numSlices // Round up to catch everything + var r = nr + for (i <- 0 until numSlices) { + slices += r.take(sliceSize).asInstanceOf[Seq[T]] + r = r.drop(sliceSize) + } + slices + } + case _ => { + val array = seq.toArray // To prevent O(n^2) operations for List etc + (0 until numSlices).map(i => { + val start = ((i * array.length.toLong) / numSlices).toInt + val end = (((i + 1) * array.length.toLong) / numSlices).toInt + array.slice(start, end).toSeq + }) + } + } + } +} |