summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAleksandar Pokopec <aleksandar.prokopec@epfl.ch>2010-09-14 14:13:44 +0000
committerAleksandar Pokopec <aleksandar.prokopec@epfl.ch>2010-09-14 14:13:44 +0000
commitad55804547c79d75e981f260cdb40dfba2ae8bfc (patch)
tree7e2ac929842be3c72ca7217010972125ffe1b51b
parentce755fb08dbd3bdfe202d3ea52a8e599d846b4c1 (diff)
downloadscala-ad55804547c79d75e981f260cdb40dfba2ae8bfc.tar.gz
scala-ad55804547c79d75e981f260cdb40dfba2ae8bfc.tar.bz2
scala-ad55804547c79d75e981f260cdb40dfba2ae8bfc.zip
Improved parallel scan performance further.
-rw-r--r--src/library/scala/collection/parallel/ParIterableLike.scala101
-rw-r--r--src/library/scala/collection/parallel/Tasks.scala20
-rw-r--r--src/library/scala/collection/parallel/mutable/ParArray.scala14
-rw-r--r--test/benchmarks/src/scala/collection/parallel/benchmarks/parallel_array/ScanLight.scala2
4 files changed, 122 insertions, 15 deletions
diff --git a/src/library/scala/collection/parallel/ParIterableLike.scala b/src/library/scala/collection/parallel/ParIterableLike.scala
index 02c3302c3c..1cc27fb186 100644
--- a/src/library/scala/collection/parallel/ParIterableLike.scala
+++ b/src/library/scala/collection/parallel/ParIterableLike.scala
@@ -546,9 +546,9 @@ extends IterableLike[T, Repr]
def scan[U >: T, That](z: U)(op: (U, U) => U)(implicit cbf: CanCombineFrom[Repr, U, That]): That = {
val array = new Array[Any](size + 1)
array(0) = z
- executeAndWaitResult(new PartialScan[U, Any](z, op, 1, size, array, parallelIterator) mapResult { st =>
- executeAndWaitResult(new ApplyScanTree[U, Any](None, op, st, array) mapResult { u =>
- executeAndWaitResult(new FromArray(array, 0, size + 1, cbf) mapResult { _.result });
+ executeAndWaitResult(new BuildScanTree[U, Any](z, op, 1, size, array, parallelIterator) mapResult { st =>
+ executeAndWaitResult(new ScanWithScanTree[U, Any](Some(z), op, st, array, array) mapResult { u =>
+ executeAndWaitResult(new FromArray(array, 0, size + 1, cbf) mapResult { _.result })
})
})
}
@@ -948,9 +948,12 @@ extends IterableLike[T, Repr]
var value: U = _
var left: ScanTree[U] = null
var right: ScanTree[U] = null
- var shouldApply = true
+ @volatile var chunkFinished = false
+ var activeScan: () => Unit = null
+ def isApplying = activeScan ne null
def isLeaf = (left eq null) && (right eq null)
+ def shouldApply = !chunkFinished && !isApplying
def applyToInterval[A >: U](elem: U, op: (U, U) => U, array: Array[A]) = {
//executeAndWait(new ApplyToArray(elem, op, from, len, array))
var i = from
@@ -960,6 +963,18 @@ extends IterableLike[T, Repr]
i += 1
}
}
+ def scanInterval[A >: U](elem: U, op: (U, U) => U, srcA: Array[A], destA: Array[A]) = {
+ val src = srcA.asInstanceOf[Array[Any]]
+ val dest = destA.asInstanceOf[Array[Any]]
+ var last = elem
+ var i = from
+ val until = from + len
+ while (i < until) {
+ last = op(last, src(i - 1).asInstanceOf[U])
+ dest(i) = last
+ i += 1
+ }
+ }
def pushDown(v: U, op: (U, U) => U) {
value = op(v, value)
if (left ne null) left.pushDown(v, op)
@@ -980,10 +995,17 @@ extends IterableLike[T, Repr]
extends Accessor[ScanTree[U], PartialScan[U, A]] {
var result: ScanTree[U] = null
def leaf(prev: Option[ScanTree[U]]) = {
- pit.scanToArray(z, op, array, from)
result = new ScanTree[U](from, len)
+ val firstElem = prev match {
+ case Some(st) => if (st.chunkFinished) {
+ result.chunkFinished = true
+ st.value
+ } else z
+ case None => z
+ }
+ pit.scanToArray(firstElem, op, array, from)
result.value = array(from + len - 1).asInstanceOf[U]
- if (from == 1) result.shouldApply = false
+ if (from == 1) result.chunkFinished = true
}
def newSubtask(p: ParIterator) = unsupported
override def split = {
@@ -994,6 +1016,7 @@ extends IterableLike[T, Repr]
}
}
override def merge(that: PartialScan[U, A]) = {
+ // create scan tree node
val left = result
val right = that.result
val ns = new ScanTree[U](left.from, left.len + right.len)
@@ -1002,6 +1025,14 @@ extends IterableLike[T, Repr]
ns.value = op(left.value, right.value)
ns.pushDownOnRight(left.value, op)
+ // commented out - didn't result in performance gains
+ // check if left tree has finished cumulating
+ // if it did, start the cumulate step in the right subtree before returning to the root
+ //if (left.chunkFinished) {
+ // right.activeScan = execute(new ApplyScanTree[U](Some(left.value), op, right, array))
+ //}
+
+ // set result
result = ns
}
}
@@ -1010,13 +1041,13 @@ extends IterableLike[T, Repr]
extends super.Task[Unit, ApplyScanTree[U, A]] {
var result = ();
def leaf(prev: Option[Unit]) = if (st.shouldApply) apply(st, first.get)
- private def apply(st: ScanTree[U], elem: U) {
+ private def apply(st: ScanTree[U], elem: U): Unit = if (st.shouldApply) {
if (st.isLeaf) st.applyToInterval(elem, op, array)
else {
apply(st.left, elem)
apply(st.right, st.left.value)
}
- }
+ } else if (st.isApplying) st.activeScan()
def split = collection.mutable.ArrayBuffer(
new ApplyScanTree(first, op, st.left, array),
new ApplyScanTree(Some(st.left.value), op, st.right, array)
@@ -1046,6 +1077,60 @@ extends IterableLike[T, Repr]
}
}
+ protected[this] class BuildScanTree[U >: T, A >: U](z: U, op: (U, U) => U, val from: Int, val len: Int, array: Array[A], val pit: ParIterator)
+ extends Accessor[ScanTree[U], BuildScanTree[U, A]] {
+ var result: ScanTree[U] = null
+ def leaf(prev: Option[ScanTree[U]]) = if ((prev != None && prev.get.chunkFinished) || from == 1) {
+ val prevElem = if (from == 1) z else prev.get.value
+ result = new ScanTree[U](from, len)
+ pit.scanToArray(prevElem, op, array, from)
+ result.value = array(from + len - 1).asInstanceOf[U]
+ result.chunkFinished = true
+ } else {
+ result = new ScanTree[U](from, len)
+ result.value = pit.fold(z)(op)
+ }
+ def newSubtask(p: ParIterator) = unsupported
+ override def split = {
+ val pits = pit.split
+ for ((p, untilp) <- pits zip pits.scanLeft(0)(_ + _.remaining); if untilp < len) yield {
+ val plen = p.remaining min (len - untilp)
+ new BuildScanTree[U, A](z, op, from + untilp, plen, array, p)
+ }
+ }
+ override def merge(that: BuildScanTree[U, A]) = {
+ // create scan tree node
+ val left = result
+ val right = that.result
+ val ns = new ScanTree[U](left.from, left.len + right.len)
+ ns.left = left
+ ns.right = right
+ ns.value = op(left.value, right.value)
+ ns.pushDownOnRight(left.value, op)
+
+ // set result
+ result = ns
+ }
+ }
+
+ protected[this] class ScanWithScanTree[U >: T, A >: U](first: Option[U], op: (U, U) => U, st: ScanTree[U], src: Array[A], dest: Array[A])
+ extends super.Task[Unit, ScanWithScanTree[U, A]] {
+ var result = ();
+ def leaf(prev: Option[Unit]) = scan(st, first.get)
+ private def scan(st: ScanTree[U], elem: U): Unit = if (!st.chunkFinished) {
+ if (st.isLeaf) st.scanInterval(elem, op, src, dest)
+ else {
+ scan(st.left, elem)
+ scan(st.right, st.left.value)
+ }
+ }
+ def split = collection.mutable.ArrayBuffer(
+ new ScanWithScanTree(first, op, st.left, src, dest),
+ new ScanWithScanTree(Some(st.left.value), op, st.right, src, dest)
+ )
+ def shouldSplitFurther = (st.left ne null) && (st.right ne null)
+ }
+
protected[this] class FromArray[S, A, That](array: Array[A], from: Int, len: Int, cbf: CanCombineFrom[Repr, S, That])
extends super.Task[Combiner[S, That], FromArray[S, A, That]] {
var result: Result = null
diff --git a/src/library/scala/collection/parallel/Tasks.scala b/src/library/scala/collection/parallel/Tasks.scala
index cbc85e43de..dabb4e813f 100644
--- a/src/library/scala/collection/parallel/Tasks.scala
+++ b/src/library/scala/collection/parallel/Tasks.scala
@@ -66,6 +66,9 @@ trait Tasks {
var environment: ExecutionEnvironment
+ /** Executes a task and returns a future. */
+ def execute[R, Tp](fjtask: TaskType[R, Tp]): () => R
+
/** Executes a task and waits for it to finish. */
def executeAndWait[R, Tp](task: TaskType[R, Tp])
@@ -171,6 +174,23 @@ trait ForkJoinTasks extends Tasks with HavingForkJoinPool {
def forkJoinPool: ForkJoinPool = environment
var environment = ForkJoinTasks.defaultForkJoinPool
+ /** Executes a task and does not wait for it to finish - instead returns a future.
+ *
+ * $fjdispatch
+ */
+ def execute[R, Tp](fjtask: Task[R, Tp]): () => R = {
+ if (currentThread.isInstanceOf[ForkJoinWorkerThread]) {
+ fjtask.fork
+ } else {
+ forkJoinPool.execute(fjtask)
+ }
+
+ () => {
+ fjtask.join
+ fjtask.result
+ }
+ }
+
/** Executes a task on a fork/join pool and waits for it to finish.
*
* $fjdispatch
diff --git a/src/library/scala/collection/parallel/mutable/ParArray.scala b/src/library/scala/collection/parallel/mutable/ParArray.scala
index aa0eb03365..1e1d3452d3 100644
--- a/src/library/scala/collection/parallel/mutable/ParArray.scala
+++ b/src/library/scala/collection/parallel/mutable/ParArray.scala
@@ -159,6 +159,8 @@ extends ParSeq[T]
sum
}
+ override def fold[U >: T](z: U)(op: (U, U) => U): U = foldLeft[U](z)(op)
+
def aggregate[S](z: S)(seqop: (S, T) => S, combop: (S, S) => S): S = foldLeft[S](z)(seqop)
override def sum[U >: T](implicit num: Numeric[U]): U = {
@@ -552,13 +554,13 @@ extends ParSeq[T]
targetarr(0) = z
// do a parallel prefix scan
- executeAndWait(new PartialScan[U, Any](z, op, 1, size, targetarr, parallelIterator) mapResult { st =>
- //println("-----------------------")
- //println(targetarr.toList)
- //st.printTree
- executeAndWaitResult(new ApplyScanTree[U, Any](None, op, st, targetarr))
+ executeAndWait(new BuildScanTree[U, Any](z, op, 1, size, targetarr, parallelIterator) mapResult { st =>
+ // println("-----------------------")
+ // println(targetarr.toList)
+ // st.printTree
+ executeAndWaitResult(new ScanWithScanTree[U, Any](Some(z), op, st, array, targetarr))
})
- //println(targetarr.toList)
+ // println(targetarr.toList)
// wrap the array into a parallel array
(new ParArray[U](targarrseq)).asInstanceOf[That]
diff --git a/test/benchmarks/src/scala/collection/parallel/benchmarks/parallel_array/ScanLight.scala b/test/benchmarks/src/scala/collection/parallel/benchmarks/parallel_array/ScanLight.scala
index 02a52ac207..85bfabfea7 100644
--- a/test/benchmarks/src/scala/collection/parallel/benchmarks/parallel_array/ScanLight.scala
+++ b/test/benchmarks/src/scala/collection/parallel/benchmarks/parallel_array/ScanLight.scala
@@ -16,7 +16,7 @@ object ScanLight extends Companion {
}
def operation(a: Cont, b: Cont) = {
val m = if (a.in < 0) 1 else 0
- new Cont(a.in + b.in + m * (0 until 10).reduceLeft(_ + _))
+ new Cont(a.in + b.in + m * (0 until 2).reduceLeft(_ + _))
}
}