aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatei Zaharia <matei@eecs.berkeley.edu>2011-11-02 15:16:02 -0700
committerMatei Zaharia <matei@eecs.berkeley.edu>2011-11-02 15:16:02 -0700
commitc2b7fd68996f9a91b4409007192badf012dd8f86 (patch)
tree8b6f51aa7f59aa1883b530a021c9a3db6bafd4e2
parentd4c8e69dc7851ffba577d4be3b3daf1723971300 (diff)
downloadspark-c2b7fd68996f9a91b4409007192badf012dd8f86.tar.gz
spark-c2b7fd68996f9a91b4409007192badf012dd8f86.tar.bz2
spark-c2b7fd68996f9a91b4409007192badf012dd8f86.zip
Make parallelize() work efficiently for ranges of Long, Double, etc
(splitting them into sub-ranges). Fixes #87.
-rw-r--r--core/src/main/scala/spark/ParallelCollection.scala23
-rw-r--r--core/src/test/scala/spark/ParallelCollectionSplitSuite.scala34
2 files changed, 52 insertions, 5 deletions
diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala
index b45f29091b..e96f73b3cf 100644
--- a/core/src/main/scala/spark/ParallelCollection.scala
+++ b/core/src/main/scala/spark/ParallelCollection.scala
@@ -1,6 +1,7 @@
package spark
-import java.util.concurrent.atomic.AtomicLong
+import scala.collection.immutable.NumericRange
+import scala.collection.mutable.ArrayBuffer
class ParallelCollectionSplit[T: ClassManifest](
val rddId: Long, val slice: Int, values: Seq[T])
@@ -40,23 +41,35 @@ extends RDD[T](sc) {
}
private object ParallelCollection {
+ // Slice a collection into numSlices sub-collections. One extra thing we do here is
+ // to treat Range collections specially, encoding the slices as other Ranges to
+ // minimize memory cost. This makes it efficient to run Spark over RDDs representing
+ // large sets of numbers.
def slice[T: ClassManifest](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
if (numSlices < 1)
throw new IllegalArgumentException("Positive number of slices required")
seq match {
case r: Range.Inclusive => {
val sign = if (r.step < 0) -1 else 1
- slice(new Range(r.start, r.end + sign, r.step).asInstanceOf[Seq[T]],
- numSlices)
+ slice(new Range(r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices)
}
case r: Range => {
(0 until numSlices).map(i => {
val start = ((i * r.length.toLong) / numSlices).toInt
val end = (((i+1) * r.length.toLong) / numSlices).toInt
- new Range(
- r.start + start * r.step, r.start + end * r.step, r.step)
+ new Range(r.start + start * r.step, r.start + end * r.step, r.step)
}).asInstanceOf[Seq[Seq[T]]]
}
+ case nr: NumericRange[_] => { // For ranges of Long, Double, BigInteger, etc
+ val slices = new ArrayBuffer[Seq[T]](numSlices)
+ val sliceSize = (nr.size + numSlices - 1) / numSlices // Round up to catch everything
+ var r = nr
+ for (i <- 0 until numSlices) {
+ slices += r.take(sliceSize).asInstanceOf[Seq[T]]
+ r = r.drop(sliceSize)
+ }
+ slices
+ }
case _ => {
val array = seq.toArray // To prevent O(n^2) operations for List etc
(0 until numSlices).map(i => {
diff --git a/core/src/test/scala/spark/ParallelCollectionSplitSuite.scala b/core/src/test/scala/spark/ParallelCollectionSplitSuite.scala
index af6ec8bae5..450c69bd58 100644
--- a/core/src/test/scala/spark/ParallelCollectionSplitSuite.scala
+++ b/core/src/test/scala/spark/ParallelCollectionSplitSuite.scala
@@ -1,5 +1,7 @@
package spark
+import scala.collection.immutable.NumericRange
+
import org.scalatest.FunSuite
import org.scalatest.prop.Checkers
import org.scalacheck.Arbitrary._
@@ -158,4 +160,36 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
}
check(prop)
}
+
+ test("exclusive ranges of longs") {
+ val data = 1L until 100L
+ val slices = ParallelCollection.slice(data, 3)
+ assert(slices.size === 3)
+ assert(slices.map(_.size).reduceLeft(_+_) === 99)
+ assert(slices.forall(_.isInstanceOf[NumericRange[_]]))
+ }
+
+ test("inclusive ranges of longs") {
+ val data = 1L to 100L
+ val slices = ParallelCollection.slice(data, 3)
+ assert(slices.size === 3)
+ assert(slices.map(_.size).reduceLeft(_+_) === 100)
+ assert(slices.forall(_.isInstanceOf[NumericRange[_]]))
+ }
+
+ test("exclusive ranges of doubles") {
+ val data = 1.0 until 100.0 by 1.0
+ val slices = ParallelCollection.slice(data, 3)
+ assert(slices.size === 3)
+ assert(slices.map(_.size).reduceLeft(_+_) === 99)
+ assert(slices.forall(_.isInstanceOf[NumericRange[_]]))
+ }
+
+ test("inclusive ranges of doubles") {
+ val data = 1.0 to 100.0 by 1.0
+ val slices = ParallelCollection.slice(data, 3)
+ assert(slices.size === 3)
+ assert(slices.map(_.size).reduceLeft(_+_) === 100)
+ assert(slices.forall(_.isInstanceOf[NumericRange[_]]))
+ }
}