From f7019903bf9c3f88f996253bb88bbd4264b28be6 Mon Sep 17 00:00:00 2001 From: Stefan Zeiger Date: Thu, 4 Feb 2016 18:47:52 +0100 Subject: Replace JoinIterator & improve ConcatIterator The new `ConcatIterator` requires only one extra lightweight wrapper object (cons cell) to be allocated compared to `JoinIterator`. All additional concatenations are then done in place with one cons cell per appended iterator. Running 1000000 iterations of the following benchmark for LHS recursion: ``` def lhs(n: Int) = (1 to n).foldLeft(Iterator.empty: Iterator[Int])((res, _) => res ++ Iterator(1)).sum ``` On 2.12.x before SI-9623 fix: ``` $ ../scala/build-sbt/quick/bin/scala -J-Xmx1024M -nc concatit.scala 1000000: 555ms 1000000: 344ms 1000000: 397ms 1000000: 309ms 1000000: 290ms 1000000: 283ms 1000000: 282ms 1000000: 281ms 1000000: 290ms 1000000: 279ms ``` With SI-9623 fix: ``` $ ../scala/build-sbt/quick/bin/scala -J-Xmx1024M -nc concatit.scala 1000000: 610ms 1000000: 324ms 1000000: 387ms 1000000: 315ms 1000000: 296ms 1000000: 300ms 1000000: 341ms 1000000: 294ms 1000000: 291ms 1000000: 281ms ``` With this version: ``` $ ../scala/build-sbt/quick/bin/scala -J-Xmx1024M -nc concatit.scala 1000000: 362ms 1000000: 162ms 1000000: 140ms 1000000: 150ms 1000000: 110ms 1000000: 57ms 1000000: 79ms 1000000: 109ms 1000000: 120ms 1000000: 49ms ``` And for RHS recursion: ``` def rhs(n: Int) = (1 to n).foldLeft(Iterator.empty: Iterator[Int])((res, _) => Iterator(1) ++ res).sum ``` On 2.12.x before SI-9623 fix: ``` StackOverflowError ``` With SI-9623 fix: ``` StackOverflowError ``` With this version: ``` $ ../scala/build-sbt/quick/bin/scala -J-Xmx1024M -nc concatit.scala 1000000: 3156ms 1000000: 1536ms 1000000: 1240ms 1000000: 1575ms 1000000: 439ms 1000000: 706ms 1000000: 1043ms 1000000: 1211ms 1000000: 515ms 1000000: 314ms ``` --- src/library/scala/collection/Iterator.scala | 90 +++++++++++++++-------------- 1 file changed, 46 insertions(+), 44 deletions(-) (limited to 'src/library') diff --git a/src/library/scala/collection/Iterator.scala b/src/library/scala/collection/Iterator.scala index 518bba6b6d..1426278954 100644 --- a/src/library/scala/collection/Iterator.scala +++ b/src/library/scala/collection/Iterator.scala @@ -11,6 +11,7 @@ package collection import mutable.ArrayBuffer import scala.annotation.{tailrec, migration} +import scala.annotation.unchecked.{uncheckedVariance => uV} import immutable.Stream /** The `Iterator` object provides various functions for creating specialized iterators. @@ -160,30 +161,49 @@ object Iterator { def next = elem } - /** Avoid stack overflows when applying ++ to lots of iterators by - * flattening the unevaluated iterators out into a vector of closures. + /** Creates an iterator to which other iterators can be appended efficiently. + * Nested ConcatIterators are merged to avoid blowing the stack. */ - private[scala] final class ConcatIterator[+A](private[this] var current: Iterator[A], initial: Vector[() => Iterator[A]]) extends Iterator[A] { - @deprecated def this(initial: Vector[() => Iterator[A]]) = this(Iterator.empty, initial) // for binary compatibility - private[this] var queue: Vector[() => Iterator[A]] = initial - private[this] var currentHasNextChecked = false + private final class ConcatIterator[+A](private var current: Iterator[A @uV]) extends Iterator[A] { + private var tail: ConcatIteratorCell[A @uV] = null + private var last: ConcatIteratorCell[A @uV] = null + private var currentHasNextChecked = false + // Advance current to the next non-empty iterator // current is set to null when all iterators are exhausted @tailrec private[this] def advance(): Boolean = { - if (queue.isEmpty) { + if (tail eq null) { current = null + last = null false } else { - current = queue.head() - queue = queue.tail - if (current.hasNext) { + current = tail.headIterator + tail = tail.tail + merge() + if (currentHasNextChecked) true + else if (current.hasNext) { currentHasNextChecked = true true } else advance() } } + + // If the current iterator is a ConcatIterator, merge it into this one + @tailrec + private[this] def merge(): Unit = + if (current.isInstanceOf[ConcatIterator[_]]) { + val c = current.asInstanceOf[ConcatIterator[A]] + current = c.current + currentHasNextChecked = c.currentHasNextChecked + if (c.tail ne null) { + c.last.tail = tail + tail = c.tail + } + merge() + } + def hasNext = if (currentHasNextChecked) true else if (current eq null) false @@ -191,47 +211,29 @@ object Iterator { currentHasNextChecked = true true } else advance() + def next() = if (hasNext) { currentHasNextChecked = false current.next() } else Iterator.empty.next() - override def ++[B >: A](that: => GenTraversableOnce[B]): Iterator[B] = - new ConcatIterator(current, queue :+ (() => that.toIterator)) - } - - private[scala] final class JoinIterator[+A](lhs: Iterator[A], that: => GenTraversableOnce[A]) extends Iterator[A] { - private[this] var state = 0 // 0: lhs not checked, 1: lhs has next, 2: switched to rhs - private[this] lazy val rhs: Iterator[A] = that.toIterator - def hasNext = state match { - case 0 => - if (lhs.hasNext) { - state = 1 - true - } else { - state = 2 - rhs.hasNext - } - case 1 => true - case _ => rhs.hasNext - } - def next() = state match { - case 0 => - if (lhs.hasNext) lhs.next() - else { - state = 2 - rhs.next() - } - case 1 => - state = 0 - lhs.next() - case _ => - rhs.next() + override def ++[B >: A](that: => GenTraversableOnce[B]): Iterator[B] = { + val c = new ConcatIteratorCell[B](that, null).asInstanceOf[ConcatIteratorCell[A]] + if(tail eq null) { + tail = c + last = c + } else { + last.tail = c + last = c + } + if(current eq null) current = Iterator.empty + this } + } - override def ++[B >: A](that: => GenTraversableOnce[B]) = - new ConcatIterator(this, Vector(() => that.toIterator)) + private[this] final class ConcatIteratorCell[A](head: => GenTraversableOnce[A], var tail: ConcatIteratorCell[A]) { + def headIterator: Iterator[A] = head.toIterator } /** Creates a delegating iterator capped by a limit count. Negative limit means unbounded. @@ -456,7 +458,7 @@ trait Iterator[+A] extends TraversableOnce[A] { * @usecase def ++(that: => Iterator[A]): Iterator[A] * @inheritdoc */ - def ++[B >: A](that: => GenTraversableOnce[B]): Iterator[B] = new Iterator.JoinIterator(self, that) + def ++[B >: A](that: => GenTraversableOnce[B]): Iterator[B] = new Iterator.ConcatIterator(self) ++ that /** Creates a new iterator by applying a function to all values produced by this iterator * and concatenating the results. -- cgit v1.2.3