-rw-r--r--  core/src/main/scala/spark/RDD.scala                        |  24 ++++++
-rw-r--r--  core/src/main/scala/spark/util/BoundedPriorityQueue.scala  |  48 ++++++++++++
-rw-r--r--  core/src/test/scala/spark/RDDSuite.scala                   |  19 +++++
3 files changed, 91 insertions(+), 0 deletions(-)
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index e6c0438d76..ec5e5e2433 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -35,6 +35,7 @@ import spark.rdd.ZippedPartitionsRDD2
 import spark.rdd.ZippedPartitionsRDD3
 import spark.rdd.ZippedPartitionsRDD4
 import spark.storage.StorageLevel
+import spark.util.BoundedPriorityQueue
 
 import SparkContext._
 
@@ -723,6 +724,29 @@ abstract class RDD[T: ClassManifest](
   }
 
   /**
+   * Returns the top K elements from this RDD as defined by
+   * the specified implicit Ordering[T].
+   * @param num the number of top elements to return
+   * @param ord the implicit ordering for T
+   * @return an array of top elements
+   */
+  def top(num: Int)(implicit ord: Ordering[T]): Array[T] = {
+    val topK = mapPartitions { items =>
+      val queue = new BoundedPriorityQueue[T](num)
+      queue ++= items
+      Iterator(queue)
+    }.reduce { (queue1, queue2) =>
+      queue1 ++= queue2
+      queue1
+    }
+
+    val builder = Array.newBuilder[T]
+    builder.sizeHint(topK.size)
+    builder ++= topK
+    builder.result()
+  }
+
+  /**
    * Save this RDD as a text file, using string representations of elements.
    */
   def saveAsTextFile(path: String) {
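
For context, a minimal usage sketch of the new method (illustrative only, not part of the commit; assumes a live SparkContext named sc). Note that top returns the K largest elements in heap order rather than sorted order, which is why the tests below sort the result before comparing:

    // Illustrative sketch, assuming a live SparkContext `sc`.
    val rdd = sc.parallelize(Seq(10, 4, 2, 12, 3), 2)
    val largest = rdd.top(3)   // contains 12, 10, 4 in heap order
    largest.sorted             // Array(4, 10, 12)
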
diff --git a/core/src/main/scala/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/spark/util/BoundedPriorityQueue.scala
new file mode 100644
index 0000000000..53ee95a02e
--- /dev/null
+++ b/core/src/main/scala/spark/util/BoundedPriorityQueue.scala
@@ -0,0 +1,48 @@
+package spark.util
+
+import java.util.{PriorityQueue => JPriorityQueue}
+import scala.collection.generic.Growable
+
+/**
+ * Bounded priority queue. This class modifies the original PriorityQueue's
+ * add/offer methods such that only the top K elements are retained. The top
+ * K elements are defined by an implicit Ordering[A].
+ */
+class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A], mf: ClassManifest[A])
+  extends JPriorityQueue[A](maxSize, ord) with Growable[A] {
+
+  override def offer(a: A): Boolean = {
+    if (size < maxSize) super.offer(a)
+    else maybeReplaceLowest(a)
+  }
+
+  override def add(a: A): Boolean = offer(a)
+
+  override def ++=(xs: TraversableOnce[A]): this.type = {
+    xs.foreach(add)
+    this
+  }
+
+  override def +=(elem: A): this.type = {
+    add(elem)
+    this
+  }
+
+  override def +=(elem1: A, elem2: A, elems: A*): this.type = {
+    this += elem1 += elem2 ++= elems
+  }
+
+  private def maybeReplaceLowest(a: A): Boolean = {
+    val head = peek()
+    if (head != null && ord.gt(a, head)) {
+      poll()
+      super.offer(a)
+    } else false
+  }
+}
+
+object BoundedPriorityQueue {
+  import scala.collection.JavaConverters._
+  implicit def asIterable[A](queue: BoundedPriorityQueue[A]): Iterable[A] = queue.asScala
+}
+
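
The queue can also be exercised on its own; a minimal sketch (illustrative, not part of the commit):

    import spark.util.BoundedPriorityQueue

    // Keep only the 3 largest of the inserted elements.
    val queue = new BoundedPriorityQueue[Int](3)
    queue ++= Seq(5, 1, 9, 7, 3)
    // peek() returns the smallest retained element, i.e. the next
    // candidate for eviction when a larger element arrives.
    queue.peek()   // 5; the queue now holds 5, 7, 9
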
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index 3f69e99780..67f3332d44 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -317,4 +317,23 @@ class RDDSuite extends FunSuite with LocalSparkContext {
     assert(sample.size === checkSample.size)
     for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
   }
+
+  test("top with predefined ordering") {
+    sc = new SparkContext("local", "test")
+    val nums = Array.range(1, 100000)
+    val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2)
+    val topK = ints.top(5)
+    assert(topK.size === 5)
+    assert(topK.sorted === nums.sorted.takeRight(5))
+  }
+
+  test("top with custom ordering") {
+    sc = new SparkContext("local", "test")
+    val words = Vector("a", "b", "c", "d")
+    implicit val ord = implicitly[Ordering[String]].reverse
+    val rdd = sc.makeRDD(words, 2)
+    val topK = rdd.top(2)
+    assert(topK.size === 2)
+    assert(topK.sorted === Array("b", "a"))
+  }
 }
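
Because top keeps the largest elements under the implicit Ordering, reversing the ordering turns it into a smallest-K query, which is what the second test exercises: under the reversed Ordering[String], "a" and "b" are the two "largest" elements, and sorting the result with the same reversed ordering still in scope yields Array("b", "a"). The same trick works for any element type; a sketch (illustrative, not part of the commit; assumes a live SparkContext sc):

    // Smallest-2 of an RDD of Ints via a reversed ordering.
    implicit val ord = implicitly[Ordering[Int]].reverse
    sc.parallelize(Seq(4, 1, 3, 2)).top(2)   // contains 1 and 2, unordered
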