diff options
author | Kan Zhang <kzhang@apache.org> | 2014-06-14 14:31:28 -0700 |
---|---|---|
committer | Matei Zaharia <matei@databricks.com> | 2014-06-14 14:31:28 -0700 |
commit | 7dd9fc67a63985493ad0482d307edd56f3af0b9d (patch) | |
tree | 5d266c0558252a193411773a7dc80d38ed92578c | |
parent | b52603b039cdfa0f8e58ef3c6229d79e732ffc58 (diff) | |
download | spark-7dd9fc67a63985493ad0482d307edd56f3af0b9d.tar.gz spark-7dd9fc67a63985493ad0482d307edd56f3af0b9d.tar.bz2 spark-7dd9fc67a63985493ad0482d307edd56f3af0b9d.zip |
[SPARK-1837] NumericRange should be partitioned in the same way as other sequences
Author: Kan Zhang <kzhang@apache.org>
Closes #776 from kanzhang/SPARK-1837 and squashes the following commits:
e48f018 [Kan Zhang] [SPARK-1837] code refactoring
67c33b5 [Kan Zhang] minor change
403f9b1 [Kan Zhang] [SPARK-1837] NumericRange should be partitioned in the same way as other sequences
-rw-r--r-- | core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala | 31 | ||||
-rw-r--r-- | core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala | 18 |
2 files changed, 37 insertions, 12 deletions
```diff
diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
index 2425929fc7..66c71bf7e8 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
@@ -117,6 +117,15 @@ private object ParallelCollectionRDD {
     if (numSlices < 1) {
       throw new IllegalArgumentException("Positive number of slices required")
     }
+    // Sequences need to be sliced at the same set of index positions for operations
+    // like RDD.zip() to behave as expected
+    def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = {
+      (0 until numSlices).iterator.map(i => {
+        val start = ((i * length) / numSlices).toInt
+        val end = (((i + 1) * length) / numSlices).toInt
+        (start, end)
+      })
+    }
     seq match {
       case r: Range.Inclusive => {
         val sign = if (r.step < 0) {
@@ -128,18 +137,17 @@ private object ParallelCollectionRDD {
           r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices)
       }
       case r: Range => {
-        (0 until numSlices).map(i => {
-          val start = ((i * r.length.toLong) / numSlices).toInt
-          val end = (((i + 1) * r.length.toLong) / numSlices).toInt
-          new Range(r.start + start * r.step, r.start + end * r.step, r.step)
-        }).asInstanceOf[Seq[Seq[T]]]
+        positions(r.length, numSlices).map({
+          case (start, end) =>
+            new Range(r.start + start * r.step, r.start + end * r.step, r.step)
+        }).toSeq.asInstanceOf[Seq[Seq[T]]]
       }
       case nr: NumericRange[_] => {
         // For ranges of Long, Double, BigInteger, etc
         val slices = new ArrayBuffer[Seq[T]](numSlices)
-        val sliceSize = (nr.size + numSlices - 1) / numSlices // Round up to catch everything
         var r = nr
-        for (i <- 0 until numSlices) {
+        for ((start, end) <- positions(nr.length, numSlices)) {
+          val sliceSize = end - start
           slices += r.take(sliceSize).asInstanceOf[Seq[T]]
           r = r.drop(sliceSize)
         }
@@ -147,11 +155,10 @@ private object ParallelCollectionRDD {
       }
       case _ => {
         val array = seq.toArray // To prevent O(n^2) operations for List etc
-        (0 until numSlices).map(i => {
-          val start = ((i * array.length.toLong) / numSlices).toInt
-          val end = (((i + 1) * array.length.toLong) / numSlices).toInt
-          array.slice(start, end).toSeq
-        })
+        positions(array.length, numSlices).map({
+          case (start, end) =>
+            array.slice(start, end).toSeq
+        }).toSeq
       }
     }
   }
diff --git a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala
index 4df36558b6..1b112f1a41 100644
--- a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala
@@ -111,6 +111,24 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
     assert(slices.forall(_.isInstanceOf[Range]))
   }
 
+  test("identical slice sizes between Range and NumericRange") {
+    val r = ParallelCollectionRDD.slice(1 to 7, 4)
+    val nr = ParallelCollectionRDD.slice(1L to 7L, 4)
+    assert(r.size === 4)
+    for (i <- 0 until r.size) {
+      assert(r(i).size === nr(i).size)
+    }
+  }
+
+  test("identical slice sizes between List and NumericRange") {
+    val r = ParallelCollectionRDD.slice(List(1, 2), 4)
+    val nr = ParallelCollectionRDD.slice(1L to 2L, 4)
+    assert(r.size === 4)
+    for (i <- 0 until r.size) {
+      assert(r(i).size === nr(i).size)
+    }
+  }
+
   test("large ranges don't overflow") {
     val N = 100 * 1000 * 1000
     val data = 0 until N
```