summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/library/scala/collection/immutable/Stream.scala79
-rw-r--r--test/junit/scala/collection/immutable/StreamTest.scala65
2 files changed, 92 insertions, 52 deletions
diff --git a/src/library/scala/collection/immutable/Stream.scala b/src/library/scala/collection/immutable/Stream.scala
index b819302460..8be0b2fee2 100644
--- a/src/library/scala/collection/immutable/Stream.scala
+++ b/src/library/scala/collection/immutable/Stream.scala
@@ -524,57 +524,9 @@ self =>
*/
override def filter(p: A => Boolean): Stream[A] = filterImpl(p, isFlipped = false) // This override is only left in 2.11 because of binary compatibility, see PR #3925
- override final def withFilter(p: A => Boolean): StreamWithFilter = new StreamWithFilter(p)
-
- /** A lazier implementation of WithFilter than TraversableLike's.
- */
- final class StreamWithFilter(p: A => Boolean) extends WithFilter(p) {
-
- override def map[B, That](f: A => B)(implicit bf: CanBuildFrom[Stream[A], B, That]): That = {
- def tailMap(coll: Stream[A]): Stream[B] = {
- var head: A = null.asInstanceOf[A]
- var tail: Stream[A] = coll
- while (true) {
- if (tail.isEmpty)
- return Stream.Empty
- head = tail.head
- tail = tail.tail
- if (p(head))
- return cons(f(head), tailMap(tail))
- }
- throw new RuntimeException()
- }
-
- if (isStreamBuilder(bf)) asThat(tailMap(Stream.this))
- else super.map(f)(bf)
- }
-
- override def flatMap[B, That](f: A => GenTraversableOnce[B])(implicit bf: CanBuildFrom[Stream[A], B, That]): That = {
- def tailFlatMap(coll: Stream[A]): Stream[B] = {
- var head: A = null.asInstanceOf[A]
- var tail: Stream[A] = coll
- while (true) {
- if (tail.isEmpty)
- return Stream.Empty
- head = tail.head
- tail = tail.tail
- if (p(head))
- return f(head).toStream append tailFlatMap(tail)
- }
- throw new RuntimeException()
- }
-
- if (isStreamBuilder(bf)) asThat(tailFlatMap(Stream.this))
- else super.flatMap(f)(bf)
- }
-
- override def foreach[B](f: A => B) =
- for (x <- self)
- if (p(x)) f(x)
-
- override def withFilter(q: A => Boolean): StreamWithFilter =
- new StreamWithFilter(x => p(x) && q(x))
- }
+ /** A FilterMonadic which allows GC of the head of stream during processing */
+ @noinline // Workaround SI-9137, see https://github.com/scala/scala/pull/4284#issuecomment-73180791
+ override final def withFilter(p: A => Boolean): FilterMonadic[A, Stream[A]] = new Stream.StreamWithFilter(this, p)
/** A lazier Iterator than LinearSeqLike's. */
override def iterator: Iterator[A] = new StreamIterator(self)
@@ -1293,6 +1245,29 @@ object Stream extends SeqFactory[Stream] {
private[immutable] def collectedTail[A, B, That](head: B, stream: Stream[A], pf: PartialFunction[A, B], bf: CanBuildFrom[Stream[A], B, That]) = {
cons(head, stream.tail.collect(pf)(bf).asInstanceOf[Stream[B]])
}
-}
+ /** An implementation of `FilterMonadic` allowing GC of the filtered-out elements of
+ * the `Stream` as it is processed.
+ *
+ * Because this is not an inner class of `Stream` with a reference to the original
+ * head, it is now possible for GC to collect any leading and filtered-out elements
+ * which do not satisfy the filter, while the tail is still processing (see SI-8990).
+ */
+ private[immutable] final class StreamWithFilter[A](sl: => Stream[A], p: A => Boolean) extends FilterMonadic[A, Stream[A]] {
+ private var s = sl // set to null to allow GC after filtered
+ private lazy val filtered = { val f = s filter p; s = null; f } // don't set to null if throw during filter
+
+ def map[B, That](f: A => B)(implicit bf: CanBuildFrom[Stream[A], B, That]): That =
+ filtered map f
+ def flatMap[B, That](f: A => scala.collection.GenTraversableOnce[B])(implicit bf: CanBuildFrom[Stream[A], B, That]): That =
+ filtered flatMap f
+
+ def foreach[U](f: A => U): Unit =
+ filtered foreach f
+
+ def withFilter(q: A => Boolean): FilterMonadic[A, Stream[A]] =
+ new StreamWithFilter[A](filtered, q)
+ }
+
+}
diff --git a/test/junit/scala/collection/immutable/StreamTest.scala b/test/junit/scala/collection/immutable/StreamTest.scala
index 6dc1c79a48..36732fef2c 100644
--- a/test/junit/scala/collection/immutable/StreamTest.scala
+++ b/test/junit/scala/collection/immutable/StreamTest.scala
@@ -5,6 +5,9 @@ import org.junit.runners.JUnit4
import org.junit.Test
import org.junit.Assert._
+import scala.ref.WeakReference
+import scala.util.Try
+
@RunWith(classOf[JUnit4])
class StreamTest {
@@ -15,4 +18,66 @@ class StreamTest {
assertTrue(Stream(1,2,3,4,5).filter(_ < 4) == Seq(1,2,3))
assertTrue(Stream(1,2,3,4,5).filterNot(_ > 4) == Seq(1,2,3,4))
}
+
+ /** Test helper to verify that the given Stream operation allows
+ * GC of the head during processing of the tail.
+ */
+ def assertStreamOpAllowsGC(op: (=> Stream[Int], Int => Unit) => Any, f: Int => Unit): Unit = {
+ val msgSuccessGC = "GC success"
+ val msgFailureGC = "GC failure"
+
+ // A stream of 500 elements at most. We will test that the head can be collected
+ // while processing the tail. After each element we will GC and wait 10 ms, so a
+ // failure to collect will take roughly 5 seconds.
+ val ref = WeakReference( Stream.from(1).take(500) )
+
+ def gcAndThrowIfCollected(n: Int): Unit = {
+ System.gc() // try to GC
+ Thread.sleep(10) // give it 10 ms
+ if (ref.get.isEmpty) throw new RuntimeException(msgSuccessGC) // we're done if head collected
+ f(n)
+ }
+
+ val res = Try { op(ref(), gcAndThrowIfCollected) }.failed // success is indicated by an
+ val msg = res.map(_.getMessage).getOrElse(msgFailureGC) // exception with expected message
+ // failure is indicated by no
+ assertTrue(msg == msgSuccessGC) // exception, or one with different message
+ }
+
+ @Test
+ def foreach_allows_GC() {
+ assertStreamOpAllowsGC(_.foreach(_), _ => ())
+ }
+
+ @Test
+ def filter_all_foreach_allows_GC() {
+ assertStreamOpAllowsGC(_.filter(_ => true).foreach(_), _ => ())
+ }
+
+ @Test // SI-8990
+ def withFilter_after_first_foreach_allows_GC: Unit = {
+ assertStreamOpAllowsGC(_.withFilter(_ > 1).foreach(_), _ => ())
+ }
+
+ @Test // SI-8990
+ def withFilter_after_first_withFilter_foreach_allows_GC: Unit = {
+ assertStreamOpAllowsGC(_.withFilter(_ > 1).withFilter(_ < 100).foreach(_), _ => ())
+ }
+
+ @Test // SI-8990
+ def withFilter_can_retry_after_exception_thrown_in_filter: Unit = {
+ // use mutable state to control an intermittent failure in filtering the Stream
+ var shouldThrow = true
+
+ val wf = Stream.from(1).take(10).withFilter { n =>
+ if (shouldThrow && n == 5) throw new RuntimeException("n == 5") else n > 5
+ }
+
+ assertTrue( Try { wf.map(identity) }.isFailure ) // throws on n == 5
+
+ shouldThrow = false // won't throw next time
+
+ assertTrue( wf.map(identity).length == 5 ) // success instead of NPE
+ }
+
}