diff options
-rw-r--r-- | src/library/scala/collection/Iterator.scala | 29 | ||||
-rw-r--r-- | test/junit/scala/collection/IteratorTest.scala | 10 |
2 files changed, 22 insertions, 17 deletions
diff --git a/src/library/scala/collection/Iterator.scala b/src/library/scala/collection/Iterator.scala index 0783beac0f..0f6ae47e89 100644 --- a/src/library/scala/collection/Iterator.scala +++ b/src/library/scala/collection/Iterator.scala @@ -10,7 +10,7 @@ package scala package collection import mutable.ArrayBuffer -import scala.annotation.migration +import scala.annotation.{ migration, tailrec } import immutable.Stream import scala.collection.generic.CanBuildFrom import scala.annotation.unchecked.{ uncheckedVariance => uV } @@ -580,29 +580,24 @@ trait Iterator[+A] extends TraversableOnce[A] { def span(p: A => Boolean): (Iterator[A], Iterator[A]) = { val self = buffered - /* - * Giving a name to following iterator (as opposed to trailing) because - * anonymous class is represented as a structural type that trailing - * iterator is referring (the finish() method) and thus triggering - * handling of structural calls. It's not what's intended here. - */ + // Must be a named class to avoid structural call to finish from trailing iterator class Leading extends AbstractIterator[A] { - val lookahead = new mutable.Queue[A] - def advance() = { - self.hasNext && p(self.head) && { + private val lookahead = new mutable.Queue[A] + private var finished = false + private def advance() = !finished && { + if (self.hasNext && p(self.head)) { lookahead += self.next true + } else { + finished = true + false } } - def finish() = { - while (advance()) () - } + @tailrec final def finish(): Unit = if (advance()) finish() def hasNext = lookahead.nonEmpty || advance() def next() = { - if (lookahead.isEmpty) - advance() - - lookahead.dequeue() + if (!hasNext) empty.next() + else lookahead.dequeue() } } val leading = new Leading diff --git a/test/junit/scala/collection/IteratorTest.scala b/test/junit/scala/collection/IteratorTest.scala index d5389afd0c..1c1e50aed9 100644 --- a/test/junit/scala/collection/IteratorTest.scala +++ b/test/junit/scala/collection/IteratorTest.scala @@ -154,4 +154,14 @@ class IteratorTest { results += (Stream from 1).toIterator.drop(10).toStream.drop(10).toIterator.next() assertSameElements(List(1,1,21), results) } + // SI-9332 + @Test def spanExhaustsLeadingIterator(): Unit = { + def it = Iterator.iterate(0)(_ + 1).take(6) + val (x, y) = it.span(_ != 1) + val z = x.toList + assertEquals(1, z.size) + assertFalse(x.hasNext) + assertEquals(1, y.next) + assertFalse(x.hasNext) // was true, after advancing underlying iterator + } } |