diff options
Diffstat (limited to 'core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala')
-rw-r--r-- | core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala | 62 |
1 files changed, 62 insertions, 0 deletions
diff --git a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala new file mode 100644 index 0000000000..0b51c23f7b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.Serializable +import java.util.{PriorityQueue => JPriorityQueue} +import scala.collection.generic.Growable +import scala.collection.JavaConverters._ + +/** + * Bounded priority queue. This class wraps the original PriorityQueue + * class and modifies it such that only the top K elements are retained. + * The top K elements are defined by an implicit Ordering[A]. + */ +class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A]) + extends Iterable[A] with Growable[A] with Serializable { + + private val underlying = new JPriorityQueue[A](maxSize, ord) + + override def iterator: Iterator[A] = underlying.iterator.asScala + + override def ++=(xs: TraversableOnce[A]): this.type = { + xs.foreach { this += _ } + this + } + + override def +=(elem: A): this.type = { + if (size < maxSize) underlying.offer(elem) + else maybeReplaceLowest(elem) + this + } + + override def +=(elem1: A, elem2: A, elems: A*): this.type = { + this += elem1 += elem2 ++= elems + } + + override def clear() { underlying.clear() } + + private def maybeReplaceLowest(a: A): Boolean = { + val head = underlying.peek() + if (head != null && ord.gt(a, head)) { + underlying.poll() + underlying.offer(a) + } else false + } +} + |