-rw-r--r--  core/src/main/scala/spark/RDD.scala                       | 24 +
-rw-r--r--  core/src/main/scala/spark/util/BoundedPriorityQueue.scala | 48 +
-rw-r--r--  core/src/test/scala/spark/RDDSuite.scala                  | 19 +
3 files changed, 91 insertions(+), 0 deletions(-)
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index e6c0438d76..ec5e5e2433 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -35,6 +35,7 @@ import spark.rdd.ZippedPartitionsRDD2
 import spark.rdd.ZippedPartitionsRDD3
 import spark.rdd.ZippedPartitionsRDD4
 import spark.storage.StorageLevel
+import spark.util.BoundedPriorityQueue
 
 import SparkContext._
 
@@ -723,6 +724,29 @@ abstract class RDD[T: ClassManifest](
   }
 
   /**
+   * Returns the top K elements from this RDD as defined by
+   * the specified implicit Ordering[T].
+   * @param num the number of top elements to return
+   * @param ord the implicit ordering for T
+   * @return an array of top elements
+   */
+  def top(num: Int)(implicit ord: Ordering[T]): Array[T] = {
+    val topK = mapPartitions { items =>
+      val queue = new BoundedPriorityQueue[T](num)
+      queue ++= items
+      Iterator(queue)
+    }.reduce { (queue1, queue2) =>
+      queue1 ++= queue2
+      queue1
+    }
+
+    val builder = Array.newBuilder[T]
+    builder.sizeHint(topK.size)
+    builder ++= topK
+    builder.result()
+  }
+
+  /**
    * Save this RDD as a text file, using string representations of elements.
    */
   def saveAsTextFile(path: String) {
diff --git a/core/src/main/scala/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/spark/util/BoundedPriorityQueue.scala
new file mode 100644
index 0000000000..53ee95a02e
--- /dev/null
+++ b/core/src/main/scala/spark/util/BoundedPriorityQueue.scala
@@ -0,0 +1,48 @@
+package spark.util
+
+import java.util.{PriorityQueue => JPriorityQueue}
+import scala.collection.generic.Growable
+
+/**
+ * Bounded priority queue. This class modifies the original PriorityQueue's
+ * add/offer methods such that only the top K elements are retained. The top
+ * K elements are defined by an implicit Ordering[A].
+ */
+class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A], mf: ClassManifest[A])
+  extends JPriorityQueue[A](maxSize, ord) with Growable[A] {
+
+  override def offer(a: A): Boolean = {
+    if (size < maxSize) super.offer(a)
+    else maybeReplaceLowest(a)
+  }
+
+  override def add(a: A): Boolean = offer(a)
+
+  override def ++=(xs: TraversableOnce[A]): this.type = {
+    xs.foreach(add)
+    this
+  }
+
+  override def +=(elem: A): this.type = {
+    add(elem)
+    this
+  }
+
+  override def +=(elem1: A, elem2: A, elems: A*): this.type = {
+    this += elem1 += elem2 ++= elems
+  }
+
+  private def maybeReplaceLowest(a: A): Boolean = {
+    val head = peek()
+    if (head != null && ord.gt(a, head)) {
+      poll()
+      super.offer(a)
+    } else false
+  }
+}
+
+object BoundedPriorityQueue {
+  import scala.collection.JavaConverters._
+  implicit def asIterable[A](queue: BoundedPriorityQueue[A]): Iterable[A] = queue.asScala
+}
+
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index 3f69e99780..67f3332d44 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -317,4 +317,23 @@ class RDDSuite extends FunSuite with LocalSparkContext {
     assert(sample.size === checkSample.size)
     for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
   }
+
+  test("top with predefined ordering") {
+    sc = new SparkContext("local", "test")
+    val nums = Array.range(1, 100000)
+    val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2)
+    val topK = ints.top(5)
+    assert(topK.size === 5)
+    assert(topK.sorted === nums.sorted.takeRight(5))
+  }
+
+  test("top with custom ordering") {
+    sc = new SparkContext("local", "test")
+    val words = Vector("a", "b", "c", "d")
+    implicit val ord = implicitly[Ordering[String]].reverse
+    val rdd = sc.makeRDD(words, 2)
+    val topK = rdd.top(2)
+    assert(topK.size === 2)
+    assert(topK.sorted === Array("b", "a"))
+  }
 }
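For context, a short usage sketch of the new top() API. This is not part of the commit: the object name TopExample and the sample values are made up for illustration, and it assumes the pre-Apache "spark" package layout shown in this diff.

    import spark.SparkContext

    object TopExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "TopExample")

        // With the default Ordering[Int], top(3) returns the 3 largest values.
        // The result comes straight out of the bounded heap, so the relative
        // order of elements within the returned array is unspecified.
        val nums = sc.makeRDD(Seq(8, 1, 42, 7, 23, 4), 2)
        println(nums.top(3).sorted.toList)              // List(8, 23, 42)

        // Passing a reversed ordering explicitly turns top() into "bottom K".
        val bottom3 = nums.top(3)(implicitly[Ordering[Int]].reverse)
        println(bottom3.sorted.toList)                  // List(1, 4, 7)

        sc.stop()
      }
    }

The per-partition BoundedPriorityQueue plus the reduce step means each partition ships at most num elements to the driver, rather than sorting or collecting the whole RDD.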
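Likewise, a minimal standalone sketch of the BoundedPriorityQueue semantics, independent of any RDD (the values are illustrative). Once the queue is full, an offered element displaces the current head only when it is strictly greater under the ordering, so the queue always holds the top maxSize elements seen so far.

    import spark.util.BoundedPriorityQueue

    object QueueExample {
      def main(args: Array[String]) {
        // Capacity 3: only the 3 largest Ints offered are retained.
        val queue = new BoundedPriorityQueue[Int](3)
        queue ++= Seq(5, 1, 9, 7, 2, 8)

        // The implicit asIterable conversion defined in the companion object
        // lets the underlying Java queue be read as a Scala Iterable.
        val kept: Iterable[Int] = queue
        println(kept.toList.sorted)   // List(7, 8, 9)
      }
    }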