diff options
author | Adriaan Moors <adriaan.moors@typesafe.com> | 2012-11-21 13:27:43 -0800 |
---|---|---|
committer | Adriaan Moors <adriaan.moors@typesafe.com> | 2012-11-21 13:27:43 -0800 |
commit | 889ceade520ae5d2d1485edf2826696fa91a0e91 (patch) | |
tree | d49195c14a87d4bfe46a3e6f438757148a2efcab | |
parent | 1cfb36317834f9bca0c3ce94e92590f7b4ace3b7 (diff) | |
parent | bc3dda2b0222d3b7cf3db491728b98f9b6110856 (diff) | |
download | scala-889ceade520ae5d2d1485edf2826696fa91a0e91.tar.gz scala-889ceade520ae5d2d1485edf2826696fa91a0e91.tar.bz2 scala-889ceade520ae5d2d1485edf2826696fa91a0e91.zip |
Merge pull request #1570 from retronym/ticket/6448
SI-6448 Collecting the spoils of PartialFun#runWith
-rw-r--r-- | src/library/scala/Option.scala | 2 | ||||
-rw-r--r-- | src/library/scala/collection/TraversableLike.scala | 2 | ||||
-rw-r--r-- | src/library/scala/collection/TraversableOnce.scala | 6 | ||||
-rw-r--r-- | src/library/scala/collection/immutable/Stream.scala | 13 | ||||
-rw-r--r-- | src/library/scala/collection/parallel/RemainsIterator.scala | 3 | ||||
-rw-r--r-- | src/library/scala/collection/parallel/mutable/ParArray.scala | 3 | ||||
-rw-r--r-- | test/files/run/t6448.check | 32 | ||||
-rw-r--r-- | test/files/run/t6448.scala | 61 |
8 files changed, 110 insertions, 12 deletions
diff --git a/src/library/scala/Option.scala b/src/library/scala/Option.scala index ace018fbf9..4b071166c7 100644 --- a/src/library/scala/Option.scala +++ b/src/library/scala/Option.scala @@ -256,7 +256,7 @@ sealed abstract class Option[+A] extends Product with Serializable { * value (if possible), or $none. */ @inline final def collect[B](pf: PartialFunction[A, B]): Option[B] = - if (!isEmpty && pf.isDefinedAt(this.get)) Some(pf(this.get)) else None + if (!isEmpty) pf.lift(this.get) else None /** Returns this $option if it is nonempty, * otherwise return the result of evaluating `alternative`. diff --git a/src/library/scala/collection/TraversableLike.scala b/src/library/scala/collection/TraversableLike.scala index fc7202b96b..5ea1faa0df 100644 --- a/src/library/scala/collection/TraversableLike.scala +++ b/src/library/scala/collection/TraversableLike.scala @@ -275,7 +275,7 @@ trait TraversableLike[+A, +Repr] extends Any def collect[B, That](pf: PartialFunction[A, B])(implicit bf: CanBuildFrom[Repr, B, That]): That = { val b = bf(repr) - for (x <- this) if (pf.isDefinedAt(x)) b += pf(x) + foreach(pf.runWith(b += _)) b.result } diff --git a/src/library/scala/collection/TraversableOnce.scala b/src/library/scala/collection/TraversableOnce.scala index d6936979e1..82cf1d1198 100644 --- a/src/library/scala/collection/TraversableOnce.scala +++ b/src/library/scala/collection/TraversableOnce.scala @@ -128,10 +128,8 @@ trait TraversableOnce[+A] extends Any with GenTraversableOnce[A] { * @example `Seq("a", 1, 5L).collectFirst({ case x: Int => x*10 }) = Some(10)` */ def collectFirst[B](pf: PartialFunction[A, B]): Option[B] = { - for (x <- self.toIterator) { // make sure to use an iterator or `seq` - if (pf isDefinedAt x) - return Some(pf(x)) - } + // make sure to use an iterator or `seq` + self.toIterator.foreach(pf.runWith(b => return Some(b))) None } diff --git a/src/library/scala/collection/immutable/Stream.scala b/src/library/scala/collection/immutable/Stream.scala index a5c6254aed..e520ddc27c 100644 --- a/src/library/scala/collection/immutable/Stream.scala +++ b/src/library/scala/collection/immutable/Stream.scala @@ -385,12 +385,17 @@ self => // 1) stackoverflows (could be achieved with tailrec, too) // 2) out of memory errors for big streams (`this` reference can be eliminated from the stack) var rest: Stream[A] = this - while (rest.nonEmpty && !pf.isDefinedAt(rest.head)) rest = rest.tail + + // Avoids calling both `pf.isDefined` and `pf.apply`. + var newHead: B = null.asInstanceOf[B] + val runWith = pf.runWith((b: B) => newHead = b) + + while (rest.nonEmpty && !runWith(rest.head)) rest = rest.tail // without the call to the companion object, a thunk is created for the tail of the new stream, // and the closure of the thunk will reference `this` if (rest.isEmpty) Stream.Empty.asInstanceOf[That] - else Stream.collectedTail(rest, pf, bf).asInstanceOf[That] + else Stream.collectedTail(newHead, rest, pf, bf).asInstanceOf[That] } } @@ -1170,8 +1175,8 @@ object Stream extends SeqFactory[Stream] { cons(stream.head, stream.tail filter p) } - private[immutable] def collectedTail[A, B, That](stream: Stream[A], pf: PartialFunction[A, B], bf: CanBuildFrom[Stream[A], B, That]) = { - cons(pf(stream.head), stream.tail.collect(pf)(bf).asInstanceOf[Stream[B]]) + private[immutable] def collectedTail[A, B, That](head: B, stream: Stream[A], pf: PartialFunction[A, B], bf: CanBuildFrom[Stream[A], B, That]) = { + cons(head, stream.tail.collect(pf)(bf).asInstanceOf[Stream[B]]) } } diff --git a/src/library/scala/collection/parallel/RemainsIterator.scala b/src/library/scala/collection/parallel/RemainsIterator.scala index 3150b0d763..732ebc3709 100644 --- a/src/library/scala/collection/parallel/RemainsIterator.scala +++ b/src/library/scala/collection/parallel/RemainsIterator.scala @@ -123,9 +123,10 @@ private[collection] trait AugmentedIterableIterator[+T] extends RemainsIterator[ def collect2combiner[S, That](pf: PartialFunction[T, S], cb: Combiner[S, That]): Combiner[S, That] = { //val cb = pbf(repr) + val runWith = pf.runWith(cb += _) while (hasNext) { val curr = next - if (pf.isDefinedAt(curr)) cb += pf(curr) + runWith(curr) } cb } diff --git a/src/library/scala/collection/parallel/mutable/ParArray.scala b/src/library/scala/collection/parallel/mutable/ParArray.scala index 33b2f0b29c..db4a59f6ba 100644 --- a/src/library/scala/collection/parallel/mutable/ParArray.scala +++ b/src/library/scala/collection/parallel/mutable/ParArray.scala @@ -405,9 +405,10 @@ self => private def collect2combiner_quick[S, That](pf: PartialFunction[T, S], a: Array[Any], cb: Builder[S, That], ntil: Int, from: Int) { var j = from + val runWith = pf.runWith(b => cb += b) while (j < ntil) { val curr = a(j).asInstanceOf[T] - if (pf.isDefinedAt(curr)) cb += pf(curr) + runWith(curr) j += 1 } } diff --git a/test/files/run/t6448.check b/test/files/run/t6448.check new file mode 100644 index 0000000000..9401568319 --- /dev/null +++ b/test/files/run/t6448.check @@ -0,0 +1,32 @@ + +=List.collect= +f(1) +f(2) +List(1) + +=List.collectFirst= +f(1) +Some(1) + +=Option.collect= +f(1) +Some(1) + +=Option.collect= +f(2) +None + +=Stream.collect= +f(1) +f(2) +List(1) + +=Stream.collectFirst= +f(1) +Some(1) + +=ParVector.collect= +(ParVector(1),2) + +=ParArray.collect= +(ParArray(1),2) diff --git a/test/files/run/t6448.scala b/test/files/run/t6448.scala new file mode 100644 index 0000000000..4d1528e500 --- /dev/null +++ b/test/files/run/t6448.scala @@ -0,0 +1,61 @@ +// Tests to show that various `collect` functions avoid calling +// both `PartialFunction#isDefinedAt` and `PartialFunction#apply`. +// +object Test { + def f(i: Int) = { println("f(" + i + ")"); true } + class Counter { + var count = 0 + def apply(i: Int) = synchronized {count += 1; true} + } + + def testing(label: String)(body: => Any) { + println(s"\n=$label=") + println(body) + } + + def main(args: Array[String]) { + testing("List.collect")(List(1, 2) collect { case x if f(x) && x < 2 => x}) + testing("List.collectFirst")(List(1, 2) collectFirst { case x if f(x) && x < 2 => x}) + testing("Option.collect")(Some(1) collect { case x if f(x) && x < 2 => x}) + testing("Option.collect")(Some(2) collect { case x if f(x) && x < 2 => x}) + testing("Stream.collect")((Stream(1, 2).collect { case x if f(x) && x < 2 => x}).toList) + testing("Stream.collectFirst")(Stream.continually(1) collectFirst { case x if f(x) && x < 2 => x}) + + import collection.parallel.ParIterable + import collection.parallel.immutable.ParVector + import collection.parallel.mutable.ParArray + testing("ParVector.collect") { + val counter = new Counter() + (ParVector(1, 2) collect { case x if counter(x) && x < 2 => x}, counter.synchronized(counter.count)) + } + + testing("ParArray.collect") { + val counter = new Counter() + (ParArray(1, 2) collect { case x if counter(x) && x < 2 => x}, counter.synchronized(counter.count)) + } + + object PendingTests { + testing("Iterator.collect")((Iterator(1, 2) collect { case x if f(x) && x < 2 => x}).toList) + + testing("List.view.collect")((List(1, 2).view collect { case x if f(x) && x < 2 => x}).force) + + // This would do the trick in Future.collect, but I haven't added this yet as there is a tradeoff + // with extra allocations to consider. + // + // pf.lift(v) match { + // case Some(x) => p success x + // case None => fail(v) + // } + testing("Future.collect") { + import concurrent.ExecutionContext.Implicits.global + import concurrent.Await + import concurrent.duration.Duration + val result = concurrent.future(1) collect { case x if f(x) => x} + Await.result(result, Duration.Inf) + } + + // TODO Future.{onSuccess, onFailure, recoverWith, andThen} + } + + } +} |