summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdriaan Moors <adriaan.moors@typesafe.com>2012-11-21 13:27:43 -0800
committerAdriaan Moors <adriaan.moors@typesafe.com>2012-11-21 13:27:43 -0800
commit889ceade520ae5d2d1485edf2826696fa91a0e91 (patch)
treed49195c14a87d4bfe46a3e6f438757148a2efcab
parent1cfb36317834f9bca0c3ce94e92590f7b4ace3b7 (diff)
parentbc3dda2b0222d3b7cf3db491728b98f9b6110856 (diff)
downloadscala-889ceade520ae5d2d1485edf2826696fa91a0e91.tar.gz
scala-889ceade520ae5d2d1485edf2826696fa91a0e91.tar.bz2
scala-889ceade520ae5d2d1485edf2826696fa91a0e91.zip
Merge pull request #1570 from retronym/ticket/6448
SI-6448 Collecting the spoils of PartialFun#runWith
-rw-r--r--src/library/scala/Option.scala2
-rw-r--r--src/library/scala/collection/TraversableLike.scala2
-rw-r--r--src/library/scala/collection/TraversableOnce.scala6
-rw-r--r--src/library/scala/collection/immutable/Stream.scala13
-rw-r--r--src/library/scala/collection/parallel/RemainsIterator.scala3
-rw-r--r--src/library/scala/collection/parallel/mutable/ParArray.scala3
-rw-r--r--test/files/run/t6448.check32
-rw-r--r--test/files/run/t6448.scala61
8 files changed, 110 insertions, 12 deletions
diff --git a/src/library/scala/Option.scala b/src/library/scala/Option.scala
index ace018fbf9..4b071166c7 100644
--- a/src/library/scala/Option.scala
+++ b/src/library/scala/Option.scala
@@ -256,7 +256,7 @@ sealed abstract class Option[+A] extends Product with Serializable {
* value (if possible), or $none.
*/
@inline final def collect[B](pf: PartialFunction[A, B]): Option[B] =
- if (!isEmpty && pf.isDefinedAt(this.get)) Some(pf(this.get)) else None
+ if (!isEmpty) pf.lift(this.get) else None
/** Returns this $option if it is nonempty,
* otherwise return the result of evaluating `alternative`.
diff --git a/src/library/scala/collection/TraversableLike.scala b/src/library/scala/collection/TraversableLike.scala
index fc7202b96b..5ea1faa0df 100644
--- a/src/library/scala/collection/TraversableLike.scala
+++ b/src/library/scala/collection/TraversableLike.scala
@@ -275,7 +275,7 @@ trait TraversableLike[+A, +Repr] extends Any
def collect[B, That](pf: PartialFunction[A, B])(implicit bf: CanBuildFrom[Repr, B, That]): That = {
val b = bf(repr)
- for (x <- this) if (pf.isDefinedAt(x)) b += pf(x)
+ foreach(pf.runWith(b += _))
b.result
}
diff --git a/src/library/scala/collection/TraversableOnce.scala b/src/library/scala/collection/TraversableOnce.scala
index d6936979e1..82cf1d1198 100644
--- a/src/library/scala/collection/TraversableOnce.scala
+++ b/src/library/scala/collection/TraversableOnce.scala
@@ -128,10 +128,8 @@ trait TraversableOnce[+A] extends Any with GenTraversableOnce[A] {
* @example `Seq("a", 1, 5L).collectFirst({ case x: Int => x*10 }) = Some(10)`
*/
def collectFirst[B](pf: PartialFunction[A, B]): Option[B] = {
- for (x <- self.toIterator) { // make sure to use an iterator or `seq`
- if (pf isDefinedAt x)
- return Some(pf(x))
- }
+ // make sure to use an iterator or `seq`
+ self.toIterator.foreach(pf.runWith(b => return Some(b)))
None
}
diff --git a/src/library/scala/collection/immutable/Stream.scala b/src/library/scala/collection/immutable/Stream.scala
index a5c6254aed..e520ddc27c 100644
--- a/src/library/scala/collection/immutable/Stream.scala
+++ b/src/library/scala/collection/immutable/Stream.scala
@@ -385,12 +385,17 @@ self =>
// 1) stackoverflows (could be achieved with tailrec, too)
// 2) out of memory errors for big streams (`this` reference can be eliminated from the stack)
var rest: Stream[A] = this
- while (rest.nonEmpty && !pf.isDefinedAt(rest.head)) rest = rest.tail
+
+ // Avoids calling both `pf.isDefined` and `pf.apply`.
+ var newHead: B = null.asInstanceOf[B]
+ val runWith = pf.runWith((b: B) => newHead = b)
+
+ while (rest.nonEmpty && !runWith(rest.head)) rest = rest.tail
// without the call to the companion object, a thunk is created for the tail of the new stream,
// and the closure of the thunk will reference `this`
if (rest.isEmpty) Stream.Empty.asInstanceOf[That]
- else Stream.collectedTail(rest, pf, bf).asInstanceOf[That]
+ else Stream.collectedTail(newHead, rest, pf, bf).asInstanceOf[That]
}
}
@@ -1170,8 +1175,8 @@ object Stream extends SeqFactory[Stream] {
cons(stream.head, stream.tail filter p)
}
- private[immutable] def collectedTail[A, B, That](stream: Stream[A], pf: PartialFunction[A, B], bf: CanBuildFrom[Stream[A], B, That]) = {
- cons(pf(stream.head), stream.tail.collect(pf)(bf).asInstanceOf[Stream[B]])
+ private[immutable] def collectedTail[A, B, That](head: B, stream: Stream[A], pf: PartialFunction[A, B], bf: CanBuildFrom[Stream[A], B, That]) = {
+ cons(head, stream.tail.collect(pf)(bf).asInstanceOf[Stream[B]])
}
}
diff --git a/src/library/scala/collection/parallel/RemainsIterator.scala b/src/library/scala/collection/parallel/RemainsIterator.scala
index 3150b0d763..732ebc3709 100644
--- a/src/library/scala/collection/parallel/RemainsIterator.scala
+++ b/src/library/scala/collection/parallel/RemainsIterator.scala
@@ -123,9 +123,10 @@ private[collection] trait AugmentedIterableIterator[+T] extends RemainsIterator[
def collect2combiner[S, That](pf: PartialFunction[T, S], cb: Combiner[S, That]): Combiner[S, That] = {
//val cb = pbf(repr)
+ val runWith = pf.runWith(cb += _)
while (hasNext) {
val curr = next
- if (pf.isDefinedAt(curr)) cb += pf(curr)
+ runWith(curr)
}
cb
}
diff --git a/src/library/scala/collection/parallel/mutable/ParArray.scala b/src/library/scala/collection/parallel/mutable/ParArray.scala
index 33b2f0b29c..db4a59f6ba 100644
--- a/src/library/scala/collection/parallel/mutable/ParArray.scala
+++ b/src/library/scala/collection/parallel/mutable/ParArray.scala
@@ -405,9 +405,10 @@ self =>
private def collect2combiner_quick[S, That](pf: PartialFunction[T, S], a: Array[Any], cb: Builder[S, That], ntil: Int, from: Int) {
var j = from
+ val runWith = pf.runWith(b => cb += b)
while (j < ntil) {
val curr = a(j).asInstanceOf[T]
- if (pf.isDefinedAt(curr)) cb += pf(curr)
+ runWith(curr)
j += 1
}
}
diff --git a/test/files/run/t6448.check b/test/files/run/t6448.check
new file mode 100644
index 0000000000..9401568319
--- /dev/null
+++ b/test/files/run/t6448.check
@@ -0,0 +1,32 @@
+
+=List.collect=
+f(1)
+f(2)
+List(1)
+
+=List.collectFirst=
+f(1)
+Some(1)
+
+=Option.collect=
+f(1)
+Some(1)
+
+=Option.collect=
+f(2)
+None
+
+=Stream.collect=
+f(1)
+f(2)
+List(1)
+
+=Stream.collectFirst=
+f(1)
+Some(1)
+
+=ParVector.collect=
+(ParVector(1),2)
+
+=ParArray.collect=
+(ParArray(1),2)
diff --git a/test/files/run/t6448.scala b/test/files/run/t6448.scala
new file mode 100644
index 0000000000..4d1528e500
--- /dev/null
+++ b/test/files/run/t6448.scala
@@ -0,0 +1,61 @@
+// Tests to show that various `collect` functions avoid calling
+// both `PartialFunction#isDefinedAt` and `PartialFunction#apply`.
+//
+object Test {
+ def f(i: Int) = { println("f(" + i + ")"); true }
+ class Counter {
+ var count = 0
+ def apply(i: Int) = synchronized {count += 1; true}
+ }
+
+ def testing(label: String)(body: => Any) {
+ println(s"\n=$label=")
+ println(body)
+ }
+
+ def main(args: Array[String]) {
+ testing("List.collect")(List(1, 2) collect { case x if f(x) && x < 2 => x})
+ testing("List.collectFirst")(List(1, 2) collectFirst { case x if f(x) && x < 2 => x})
+ testing("Option.collect")(Some(1) collect { case x if f(x) && x < 2 => x})
+ testing("Option.collect")(Some(2) collect { case x if f(x) && x < 2 => x})
+ testing("Stream.collect")((Stream(1, 2).collect { case x if f(x) && x < 2 => x}).toList)
+ testing("Stream.collectFirst")(Stream.continually(1) collectFirst { case x if f(x) && x < 2 => x})
+
+ import collection.parallel.ParIterable
+ import collection.parallel.immutable.ParVector
+ import collection.parallel.mutable.ParArray
+ testing("ParVector.collect") {
+ val counter = new Counter()
+ (ParVector(1, 2) collect { case x if counter(x) && x < 2 => x}, counter.synchronized(counter.count))
+ }
+
+ testing("ParArray.collect") {
+ val counter = new Counter()
+ (ParArray(1, 2) collect { case x if counter(x) && x < 2 => x}, counter.synchronized(counter.count))
+ }
+
+ object PendingTests {
+ testing("Iterator.collect")((Iterator(1, 2) collect { case x if f(x) && x < 2 => x}).toList)
+
+ testing("List.view.collect")((List(1, 2).view collect { case x if f(x) && x < 2 => x}).force)
+
+ // This would do the trick in Future.collect, but I haven't added this yet as there is a tradeoff
+ // with extra allocations to consider.
+ //
+ // pf.lift(v) match {
+ // case Some(x) => p success x
+ // case None => fail(v)
+ // }
+ testing("Future.collect") {
+ import concurrent.ExecutionContext.Implicits.global
+ import concurrent.Await
+ import concurrent.duration.Duration
+ val result = concurrent.future(1) collect { case x if f(x) => x}
+ Await.result(result, Duration.Inf)
+ }
+
+ // TODO Future.{onSuccess, onFailure, recoverWith, andThen}
+ }
+
+ }
+}