author    Aaron Josephs <ajoseph4@binghamton.edu>  2015-02-22 22:09:06 -0800
committer Josh Rosen <joshrosen@databricks.com>    2015-02-22 22:09:06 -0800
commit    e4f9d03d728bc6fbfb6ebc7d15b4ba328f98f3dc (patch)
tree      927ea84a1d07d8ff59c38e047d02cd44d755232f /core
parent    275b1bef897d775f1f7743378ca3e09e36160136 (diff)
[SPARK-911] allow efficient queries for a range if RDD is partitioned with RangePartitioner

Author: Aaron Josephs <ajoseph4@binghamton.edu>

Closes #1381 from aaronjosephs/PLAT-911 and squashes the following commits:

e30ade5 [Aaron Josephs] [SPARK-911] allow efficient queries for a range if RDD is partitioned with RangePartitioner
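For orientation, here is a minimal usage sketch of the `filterByRange` API this patch adds. The app name, master setting, data, and bounds are illustrative, not part of the patch:

import org.apache.spark.{SparkConf, SparkContext}

// Minimal sketch, assuming a local Spark context of this era; the names
// and values here are illustrative, not taken from the patch itself.
val conf = new SparkConf().setAppName("FilterByRangeSketch").setMaster("local[2]")
val sc = new SparkContext(conf)

// sortByKey shuffles with a RangePartitioner, so the result is eligible
// for partition pruning in filterByRange.
val sorted = sc.parallelize((1 to 1000).map(x => (x, x)), numSlices = 10).sortByKey()

// Only the partitions whose key ranges can overlap [100, 200] are scanned;
// on an RDD without a RangePartitioner this falls back to a plain filter.
val slice = sorted.filterByRange(100, 200).collect()

sc.stop()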
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala | 23
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala        | 28
2 files changed, 51 insertions(+), 0 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
index 144f679a59..6fdfdb734d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
@@ -75,4 +75,27 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
     new ShuffledRDD[K, V, V](self, partitioner).setKeyOrdering(ordering)
   }
 
+  /**
+   * Returns an RDD containing only the elements in the inclusive range `lower` to `upper`.
+   * If the RDD has been partitioned using a `RangePartitioner`, then this operation can be
+   * performed efficiently by only scanning the partitions that might contain matching elements.
+   * Otherwise, a standard `filter` is applied to all partitions.
+   */
+  def filterByRange(lower: K, upper: K): RDD[P] = {
+
+    def inRange(k: K): Boolean = ordering.gteq(k, lower) && ordering.lteq(k, upper)
+
+    val rddToFilter: RDD[P] = self.partitioner match {
+      case Some(rp: RangePartitioner[K, V]) => {
+        val partitionIndices = (rp.getPartition(lower), rp.getPartition(upper)) match {
+          case (l, u) => Math.min(l, u) to Math.max(l, u)
+        }
+        PartitionPruningRDD.create(self, partitionIndices.contains)
+      }
+      case _ =>
+        self
+    }
+    rddToFilter.filter { case (k, v) => inRange(k) }
+  }
+
 }
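One detail worth noting in the hunk above: `rp.getPartition(lower)` is not guaranteed to be the smaller index, because a `RangePartitioner` can be built over a descending ordering (e.g. via `sortByKey(false)`). The `Math.min`/`Math.max` pair normalizes this before pruning. A standalone sketch of that index computation; the helper name and indices are hypothetical:

// Hypothetical helper mirroring the index logic in filterByRange: given
// the partition indices of the two bounds, produce the inclusive range of
// partitions to keep, regardless of the partitioner's sort direction.
def partitionsToScan(lowerIdx: Int, upperIdx: Int): Range =
  math.min(lowerIdx, upperIdx) to math.max(lowerIdx, upperIdx)

// Ascending partitioner: lower maps below upper.
assert(partitionsToScan(2, 7) == (2 to 7))
// Descending partitioner: the same bounds arrive swapped, same result.
assert(partitionsToScan(7, 2) == (2 to 7))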
diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala
index a40f2ffeff..64b1c24c47 100644
--- a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala
@@ -119,5 +119,33 @@ class SortingSuite extends FunSuite with SharedSparkContext with Matchers with Logging {
     partitions(1).last should be > partitions(2).head
     partitions(2).last should be > partitions(3).head
   }
+
+  test("get a range of elements in a sorted RDD that is on one partition") {
+    val pairArr = (1 to 1000).map(x => (x, x)).toArray
+    val sorted = sc.parallelize(pairArr, 10).sortByKey()
+    val range = sorted.filterByRange(20, 40).collect()
+    assert((20 to 40).toArray === range.map(_._1))
+  }
+
+  test("get a range of elements over multiple partitions in an RDD sorted in descending order") {
+    val pairArr = (1000 to 1 by -1).map(x => (x, x)).toArray
+    val sorted = sc.parallelize(pairArr, 10).sortByKey(false)
+    val range = sorted.filterByRange(200, 800).collect()
+    assert((800 to 200 by -1).toArray === range.map(_._1))
+  }
+
+  test("get a range of elements in an RDD not partitioned by a range partitioner") {
+    val pairArr = util.Random.shuffle((1 to 1000).toList).map(x => (x, x))
+    val pairs = sc.parallelize(pairArr, 10)
+    val range = pairs.filterByRange(200, 800).collect()
+    assert((200 to 800).toArray === range.map(_._1).sorted)
+  }
+
+  test("get a range of elements over multiple partitions but not taking up full partitions") {
+    val pairArr = (1000 to 1 by -1).map(x => (x, x)).toArray
+    val sorted = sc.parallelize(pairArr, 10).sortByKey(false)
+    val range = sorted.filterByRange(250, 850).collect()
+    assert((850 to 250 by -1).toArray === range.map(_._1))
+  }
 
 }