diff options
-rw-r--r-- | src/library/scala/collection/immutable/Stream.scala | 79 | ||||
-rw-r--r-- | test/junit/scala/collection/immutable/StreamTest.scala | 65 |
2 files changed, 92 insertions, 52 deletions
diff --git a/src/library/scala/collection/immutable/Stream.scala b/src/library/scala/collection/immutable/Stream.scala index b819302460..8be0b2fee2 100644 --- a/src/library/scala/collection/immutable/Stream.scala +++ b/src/library/scala/collection/immutable/Stream.scala @@ -524,57 +524,9 @@ self => */ override def filter(p: A => Boolean): Stream[A] = filterImpl(p, isFlipped = false) // This override is only left in 2.11 because of binary compatibility, see PR #3925 - override final def withFilter(p: A => Boolean): StreamWithFilter = new StreamWithFilter(p) - - /** A lazier implementation of WithFilter than TraversableLike's. - */ - final class StreamWithFilter(p: A => Boolean) extends WithFilter(p) { - - override def map[B, That](f: A => B)(implicit bf: CanBuildFrom[Stream[A], B, That]): That = { - def tailMap(coll: Stream[A]): Stream[B] = { - var head: A = null.asInstanceOf[A] - var tail: Stream[A] = coll - while (true) { - if (tail.isEmpty) - return Stream.Empty - head = tail.head - tail = tail.tail - if (p(head)) - return cons(f(head), tailMap(tail)) - } - throw new RuntimeException() - } - - if (isStreamBuilder(bf)) asThat(tailMap(Stream.this)) - else super.map(f)(bf) - } - - override def flatMap[B, That](f: A => GenTraversableOnce[B])(implicit bf: CanBuildFrom[Stream[A], B, That]): That = { - def tailFlatMap(coll: Stream[A]): Stream[B] = { - var head: A = null.asInstanceOf[A] - var tail: Stream[A] = coll - while (true) { - if (tail.isEmpty) - return Stream.Empty - head = tail.head - tail = tail.tail - if (p(head)) - return f(head).toStream append tailFlatMap(tail) - } - throw new RuntimeException() - } - - if (isStreamBuilder(bf)) asThat(tailFlatMap(Stream.this)) - else super.flatMap(f)(bf) - } - - override def foreach[B](f: A => B) = - for (x <- self) - if (p(x)) f(x) - - override def withFilter(q: A => Boolean): StreamWithFilter = - new StreamWithFilter(x => p(x) && q(x)) - } + /** A FilterMonadic which allows GC of the head of stream during processing */ + @noinline // Workaround SI-9137, see https://github.com/scala/scala/pull/4284#issuecomment-73180791 + override final def withFilter(p: A => Boolean): FilterMonadic[A, Stream[A]] = new Stream.StreamWithFilter(this, p) /** A lazier Iterator than LinearSeqLike's. */ override def iterator: Iterator[A] = new StreamIterator(self) @@ -1293,6 +1245,29 @@ object Stream extends SeqFactory[Stream] { private[immutable] def collectedTail[A, B, That](head: B, stream: Stream[A], pf: PartialFunction[A, B], bf: CanBuildFrom[Stream[A], B, That]) = { cons(head, stream.tail.collect(pf)(bf).asInstanceOf[Stream[B]]) } -} + /** An implementation of `FilterMonadic` allowing GC of the filtered-out elements of + * the `Stream` as it is processed. + * + * Because this is not an inner class of `Stream` with a reference to the original + * head, it is now possible for GC to collect any leading and filtered-out elements + * which do not satisfy the filter, while the tail is still processing (see SI-8990). + */ + private[immutable] final class StreamWithFilter[A](sl: => Stream[A], p: A => Boolean) extends FilterMonadic[A, Stream[A]] { + private var s = sl // set to null to allow GC after filtered + private lazy val filtered = { val f = s filter p; s = null; f } // don't set to null if throw during filter + + def map[B, That](f: A => B)(implicit bf: CanBuildFrom[Stream[A], B, That]): That = + filtered map f + def flatMap[B, That](f: A => scala.collection.GenTraversableOnce[B])(implicit bf: CanBuildFrom[Stream[A], B, That]): That = + filtered flatMap f + + def foreach[U](f: A => U): Unit = + filtered foreach f + + def withFilter(q: A => Boolean): FilterMonadic[A, Stream[A]] = + new StreamWithFilter[A](filtered, q) + } + +} diff --git a/test/junit/scala/collection/immutable/StreamTest.scala b/test/junit/scala/collection/immutable/StreamTest.scala index 6dc1c79a48..36732fef2c 100644 --- a/test/junit/scala/collection/immutable/StreamTest.scala +++ b/test/junit/scala/collection/immutable/StreamTest.scala @@ -5,6 +5,9 @@ import org.junit.runners.JUnit4 import org.junit.Test import org.junit.Assert._ +import scala.ref.WeakReference +import scala.util.Try + @RunWith(classOf[JUnit4]) class StreamTest { @@ -15,4 +18,66 @@ class StreamTest { assertTrue(Stream(1,2,3,4,5).filter(_ < 4) == Seq(1,2,3)) assertTrue(Stream(1,2,3,4,5).filterNot(_ > 4) == Seq(1,2,3,4)) } + + /** Test helper to verify that the given Stream operation allows + * GC of the head during processing of the tail. + */ + def assertStreamOpAllowsGC(op: (=> Stream[Int], Int => Unit) => Any, f: Int => Unit): Unit = { + val msgSuccessGC = "GC success" + val msgFailureGC = "GC failure" + + // A stream of 500 elements at most. We will test that the head can be collected + // while processing the tail. After each element we will GC and wait 10 ms, so a + // failure to collect will take roughly 5 seconds. + val ref = WeakReference( Stream.from(1).take(500) ) + + def gcAndThrowIfCollected(n: Int): Unit = { + System.gc() // try to GC + Thread.sleep(10) // give it 10 ms + if (ref.get.isEmpty) throw new RuntimeException(msgSuccessGC) // we're done if head collected + f(n) + } + + val res = Try { op(ref(), gcAndThrowIfCollected) }.failed // success is indicated by an + val msg = res.map(_.getMessage).getOrElse(msgFailureGC) // exception with expected message + // failure is indicated by no + assertTrue(msg == msgSuccessGC) // exception, or one with different message + } + + @Test + def foreach_allows_GC() { + assertStreamOpAllowsGC(_.foreach(_), _ => ()) + } + + @Test + def filter_all_foreach_allows_GC() { + assertStreamOpAllowsGC(_.filter(_ => true).foreach(_), _ => ()) + } + + @Test // SI-8990 + def withFilter_after_first_foreach_allows_GC: Unit = { + assertStreamOpAllowsGC(_.withFilter(_ > 1).foreach(_), _ => ()) + } + + @Test // SI-8990 + def withFilter_after_first_withFilter_foreach_allows_GC: Unit = { + assertStreamOpAllowsGC(_.withFilter(_ > 1).withFilter(_ < 100).foreach(_), _ => ()) + } + + @Test // SI-8990 + def withFilter_can_retry_after_exception_thrown_in_filter: Unit = { + // use mutable state to control an intermittent failure in filtering the Stream + var shouldThrow = true + + val wf = Stream.from(1).take(10).withFilter { n => + if (shouldThrow && n == 5) throw new RuntimeException("n == 5") else n > 5 + } + + assertTrue( Try { wf.map(identity) }.isFailure ) // throws on n == 5 + + shouldThrow = false // won't throw next time + + assertTrue( wf.map(identity).length == 5 ) // success instead of NPE + } + } |