From 8a537b7d7da03833946a6a2f4461da2080363c88 Mon Sep 17 00:00:00 2001 From: Paul Phillips Date: Mon, 29 Oct 2012 17:17:43 -0700 Subject: Fix SI-6584, Stream#distinct uses too much memory. Nesting recursive calls in Stream is always a dicey business. --- src/library/scala/collection/immutable/Stream.scala | 13 ++++++++++--- test/files/run/t6584.check | 8 ++++++++ test/files/run/t6584.scala | 16 ++++++++++++++++ 3 files changed, 34 insertions(+), 3 deletions(-) create mode 100644 test/files/run/t6584.check create mode 100644 test/files/run/t6584.scala diff --git a/src/library/scala/collection/immutable/Stream.scala b/src/library/scala/collection/immutable/Stream.scala index 461a375317..78c4d76eda 100644 --- a/src/library/scala/collection/immutable/Stream.scala +++ b/src/library/scala/collection/immutable/Stream.scala @@ -841,9 +841,16 @@ self => * // produces: "1, 2, 3, 4, 5, 6" * }}} */ - override def distinct: Stream[A] = - if (isEmpty) this - else cons(head, tail.filter(head != _).distinct) + override def distinct: Stream[A] = { + // This should use max memory proportional to N, whereas + // recursively calling distinct on the tail is N^2. + def loop(seen: Set[A], rest: Stream[A]): Stream[A] = { + if (rest.isEmpty) rest + else if (seen(rest.head)) loop(seen, rest.tail) + else cons(rest.head, loop(seen + rest.head, rest.tail)) + } + loop(Set(), this) + } /** Returns a new sequence of given length containing the elements of this * sequence followed by zero or more occurrences of given elements. diff --git a/test/files/run/t6584.check b/test/files/run/t6584.check new file mode 100644 index 0000000000..35c8688751 --- /dev/null +++ b/test/files/run/t6584.check @@ -0,0 +1,8 @@ +Array: 102400 +Vector: 102400 +List: 102400 +Stream: 102400 +Array: 102400 +Vector: 102400 +List: 102400 +Stream: 102400 diff --git a/test/files/run/t6584.scala b/test/files/run/t6584.scala new file mode 100644 index 0000000000..24c236ef35 --- /dev/null +++ b/test/files/run/t6584.scala @@ -0,0 +1,16 @@ +object Test { + def main(args: Array[String]): Unit = { + val size = 100 * 1024 + val doubled = (1 to size) ++ (1 to size) + + println("Array: " + Array.tabulate(size)(x => x).distinct.size) + println("Vector: " + Vector.tabulate(size)(x => x).distinct.size) + println("List: " + List.tabulate(size)(x => x).distinct.size) + println("Stream: " + Stream.tabulate(size)(x => x).distinct.size) + + println("Array: " + doubled.toArray.distinct.size) + println("Vector: " + doubled.toVector.distinct.size) + println("List: " + doubled.toList.distinct.size) + println("Stream: " + doubled.toStream.distinct.size) + } +} -- cgit v1.2.3 From 4e4060f4faee791759417f1a598322e90623464d Mon Sep 17 00:00:00 2001 From: Paul Phillips Date: Mon, 29 Oct 2012 20:20:35 -0700 Subject: Implementation of Stream#dropRight. "There's nothing we can do about dropRight," you say? Oh but there is. --- .../scala/collection/immutable/Stream.scala | 32 ++++++++++++++++++---- test/files/run/streams.check | 1 + test/files/run/streams.scala | 5 +++- 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/src/library/scala/collection/immutable/Stream.scala b/src/library/scala/collection/immutable/Stream.scala index 78c4d76eda..5566806c55 100644 --- a/src/library/scala/collection/immutable/Stream.scala +++ b/src/library/scala/collection/immutable/Stream.scala @@ -181,6 +181,7 @@ import scala.language.implicitConversions * @define coll stream * @define orderDependent * @define orderDependentFold + * @define willTerminateInf Note: lazily evaluated; will terminate for infinite-sized collections. */ abstract class Stream[+A] extends AbstractSeq[A] with LinearSeq[A] @@ -286,9 +287,8 @@ self => len } - /** It's an imperfect world, but at least we can bottle up the - * imperfection in a capsule. - */ + // It's an imperfect world, but at least we can bottle up the + // imperfection in a capsule. @inline private def asThat[That](x: AnyRef): That = x.asInstanceOf[That] @inline private def asStream[B](x: AnyRef): Stream[B] = x.asInstanceOf[Stream[B]] @inline private def isStreamBuilder[B, That](bf: CanBuildFrom[Stream[A], B, That]) = @@ -725,10 +725,15 @@ self => * // produces: "5, 6, 7, 8, 9" * }}} */ - override def take(n: Int): Stream[A] = + override def take(n: Int): Stream[A] = ( + // Note that the n == 1 condition appears redundant but is not. + // It prevents "tail" from being referenced (and its head being evaluated) + // when obtaining the last element of the result. Such are the challenges + // of working with a lazy-but-not-really sequence. if (n <= 0 || isEmpty) Stream.empty else if (n == 1) cons(head, Stream.empty) else cons(head, tail take n-1) + ) @tailrec final override def drop(n: Int): Stream[A] = if (n <= 0 || isEmpty) this @@ -784,8 +789,23 @@ self => these } - // there's nothing we can do about dropRight, so we just keep the definition - // in LinearSeq + /** + * @inheritdoc + * $willTerminateInf + */ + override def dropRight(n: Int): Stream[A] = { + // We make dropRight work for possibly infinite streams by carrying + // a buffer of the dropped size. As long as the buffer is full and the + // rest is non-empty, we can feed elements off the buffer head. When + // the rest becomes empty, the full buffer is the dropped elements. + def advance(stub0: List[A], stub1: List[A], rest: Stream[A]): Stream[A] = { + if (rest.isEmpty) Stream.empty + else if (stub0.isEmpty) advance(stub1.reverse, Nil, rest) + else cons(stub0.head, advance(stub0.tail, rest.head :: stub1, rest.tail)) + } + if (n <= 0) this + else advance((this take n).toList, Nil, this drop n) + } /** Returns the longest prefix of this `Stream` whose elements satisfy the * predicate `p`. diff --git a/test/files/run/streams.check b/test/files/run/streams.check index 7f894052d9..032057d4a1 100644 --- a/test/files/run/streams.check +++ b/test/files/run/streams.check @@ -23,3 +23,4 @@ Stream(100001, ?) true true 705082704 +6 diff --git a/test/files/run/streams.scala b/test/files/run/streams.scala index 51b4e5d76c..dc5d0204ac 100644 --- a/test/files/run/streams.scala +++ b/test/files/run/streams.scala @@ -29,7 +29,7 @@ object Test extends App { def powers(x: Int) = if ((x&(x-1)) == 0) Some(x) else None println(s3.flatMap(powers).reverse.head) - // large enough to generate StackOverflows (on most systems) + // large enough to generate StackOverflows (on most systems) // unless the following methods are tail call optimized. val size = 100000 @@ -43,4 +43,7 @@ object Test extends App { println(Stream.from(1).take(size).foldLeft(0)(_ + _)) val arr = new Array[Int](size) Stream.from(1).take(size).copyToArray(arr, 0) + + // dropRight terminates + println(Stream from 1 dropRight 1000 take 3 sum) } -- cgit v1.2.3