From ad55804547c79d75e981f260cdb40dfba2ae8bfc Mon Sep 17 00:00:00 2001 From: Aleksandar Pokopec Date: Tue, 14 Sep 2010 14:13:44 +0000 Subject: Improved parallel scan performance further. --- .../collection/parallel/ParIterableLike.scala | 101 +++++++++++++++++++-- src/library/scala/collection/parallel/Tasks.scala | 20 ++++ .../collection/parallel/mutable/ParArray.scala | 14 +-- .../benchmarks/parallel_array/ScanLight.scala | 2 +- 4 files changed, 122 insertions(+), 15 deletions(-) diff --git a/src/library/scala/collection/parallel/ParIterableLike.scala b/src/library/scala/collection/parallel/ParIterableLike.scala index 02c3302c3c..1cc27fb186 100644 --- a/src/library/scala/collection/parallel/ParIterableLike.scala +++ b/src/library/scala/collection/parallel/ParIterableLike.scala @@ -546,9 +546,9 @@ extends IterableLike[T, Repr] def scan[U >: T, That](z: U)(op: (U, U) => U)(implicit cbf: CanCombineFrom[Repr, U, That]): That = { val array = new Array[Any](size + 1) array(0) = z - executeAndWaitResult(new PartialScan[U, Any](z, op, 1, size, array, parallelIterator) mapResult { st => - executeAndWaitResult(new ApplyScanTree[U, Any](None, op, st, array) mapResult { u => - executeAndWaitResult(new FromArray(array, 0, size + 1, cbf) mapResult { _.result }); + executeAndWaitResult(new BuildScanTree[U, Any](z, op, 1, size, array, parallelIterator) mapResult { st => + executeAndWaitResult(new ScanWithScanTree[U, Any](Some(z), op, st, array, array) mapResult { u => + executeAndWaitResult(new FromArray(array, 0, size + 1, cbf) mapResult { _.result }) }) }) } @@ -948,9 +948,12 @@ extends IterableLike[T, Repr] var value: U = _ var left: ScanTree[U] = null var right: ScanTree[U] = null - var shouldApply = true + @volatile var chunkFinished = false + var activeScan: () => Unit = null + def isApplying = activeScan ne null def isLeaf = (left eq null) && (right eq null) + def shouldApply = !chunkFinished && !isApplying def applyToInterval[A >: U](elem: U, op: (U, U) => U, array: Array[A]) = { //executeAndWait(new ApplyToArray(elem, op, from, len, array)) var i = from @@ -960,6 +963,18 @@ extends IterableLike[T, Repr] i += 1 } } + def scanInterval[A >: U](elem: U, op: (U, U) => U, srcA: Array[A], destA: Array[A]) = { + val src = srcA.asInstanceOf[Array[Any]] + val dest = destA.asInstanceOf[Array[Any]] + var last = elem + var i = from + val until = from + len + while (i < until) { + last = op(last, src(i - 1).asInstanceOf[U]) + dest(i) = last + i += 1 + } + } def pushDown(v: U, op: (U, U) => U) { value = op(v, value) if (left ne null) left.pushDown(v, op) @@ -980,10 +995,17 @@ extends IterableLike[T, Repr] extends Accessor[ScanTree[U], PartialScan[U, A]] { var result: ScanTree[U] = null def leaf(prev: Option[ScanTree[U]]) = { - pit.scanToArray(z, op, array, from) result = new ScanTree[U](from, len) + val firstElem = prev match { + case Some(st) => if (st.chunkFinished) { + result.chunkFinished = true + st.value + } else z + case None => z + } + pit.scanToArray(firstElem, op, array, from) result.value = array(from + len - 1).asInstanceOf[U] - if (from == 1) result.shouldApply = false + if (from == 1) result.chunkFinished = true } def newSubtask(p: ParIterator) = unsupported override def split = { @@ -994,6 +1016,7 @@ extends IterableLike[T, Repr] } } override def merge(that: PartialScan[U, A]) = { + // create scan tree node val left = result val right = that.result val ns = new ScanTree[U](left.from, left.len + right.len) @@ -1002,6 +1025,14 @@ extends IterableLike[T, Repr] ns.value = op(left.value, right.value) ns.pushDownOnRight(left.value, op) + // commented out - didn't result in performance gains + // check if left tree has finished cumulating + // if it did, start the cumulate step in the right subtree before returning to the root + //if (left.chunkFinished) { + // right.activeScan = execute(new ApplyScanTree[U](Some(left.value), op, right, array)) + //} + + // set result result = ns } } @@ -1010,13 +1041,13 @@ extends IterableLike[T, Repr] extends super.Task[Unit, ApplyScanTree[U, A]] { var result = (); def leaf(prev: Option[Unit]) = if (st.shouldApply) apply(st, first.get) - private def apply(st: ScanTree[U], elem: U) { + private def apply(st: ScanTree[U], elem: U): Unit = if (st.shouldApply) { if (st.isLeaf) st.applyToInterval(elem, op, array) else { apply(st.left, elem) apply(st.right, st.left.value) } - } + } else if (st.isApplying) st.activeScan() def split = collection.mutable.ArrayBuffer( new ApplyScanTree(first, op, st.left, array), new ApplyScanTree(Some(st.left.value), op, st.right, array) @@ -1046,6 +1077,60 @@ extends IterableLike[T, Repr] } } + protected[this] class BuildScanTree[U >: T, A >: U](z: U, op: (U, U) => U, val from: Int, val len: Int, array: Array[A], val pit: ParIterator) + extends Accessor[ScanTree[U], BuildScanTree[U, A]] { + var result: ScanTree[U] = null + def leaf(prev: Option[ScanTree[U]]) = if ((prev != None && prev.get.chunkFinished) || from == 1) { + val prevElem = if (from == 1) z else prev.get.value + result = new ScanTree[U](from, len) + pit.scanToArray(prevElem, op, array, from) + result.value = array(from + len - 1).asInstanceOf[U] + result.chunkFinished = true + } else { + result = new ScanTree[U](from, len) + result.value = pit.fold(z)(op) + } + def newSubtask(p: ParIterator) = unsupported + override def split = { + val pits = pit.split + for ((p, untilp) <- pits zip pits.scanLeft(0)(_ + _.remaining); if untilp < len) yield { + val plen = p.remaining min (len - untilp) + new BuildScanTree[U, A](z, op, from + untilp, plen, array, p) + } + } + override def merge(that: BuildScanTree[U, A]) = { + // create scan tree node + val left = result + val right = that.result + val ns = new ScanTree[U](left.from, left.len + right.len) + ns.left = left + ns.right = right + ns.value = op(left.value, right.value) + ns.pushDownOnRight(left.value, op) + + // set result + result = ns + } + } + + protected[this] class ScanWithScanTree[U >: T, A >: U](first: Option[U], op: (U, U) => U, st: ScanTree[U], src: Array[A], dest: Array[A]) + extends super.Task[Unit, ScanWithScanTree[U, A]] { + var result = (); + def leaf(prev: Option[Unit]) = scan(st, first.get) + private def scan(st: ScanTree[U], elem: U): Unit = if (!st.chunkFinished) { + if (st.isLeaf) st.scanInterval(elem, op, src, dest) + else { + scan(st.left, elem) + scan(st.right, st.left.value) + } + } + def split = collection.mutable.ArrayBuffer( + new ScanWithScanTree(first, op, st.left, src, dest), + new ScanWithScanTree(Some(st.left.value), op, st.right, src, dest) + ) + def shouldSplitFurther = (st.left ne null) && (st.right ne null) + } + protected[this] class FromArray[S, A, That](array: Array[A], from: Int, len: Int, cbf: CanCombineFrom[Repr, S, That]) extends super.Task[Combiner[S, That], FromArray[S, A, That]] { var result: Result = null diff --git a/src/library/scala/collection/parallel/Tasks.scala b/src/library/scala/collection/parallel/Tasks.scala index cbc85e43de..dabb4e813f 100644 --- a/src/library/scala/collection/parallel/Tasks.scala +++ b/src/library/scala/collection/parallel/Tasks.scala @@ -66,6 +66,9 @@ trait Tasks { var environment: ExecutionEnvironment + /** Executes a task and returns a future. */ + def execute[R, Tp](fjtask: TaskType[R, Tp]): () => R + /** Executes a task and waits for it to finish. */ def executeAndWait[R, Tp](task: TaskType[R, Tp]) @@ -171,6 +174,23 @@ trait ForkJoinTasks extends Tasks with HavingForkJoinPool { def forkJoinPool: ForkJoinPool = environment var environment = ForkJoinTasks.defaultForkJoinPool + /** Executes a task and does not wait for it to finish - instead returns a future. + * + * $fjdispatch + */ + def execute[R, Tp](fjtask: Task[R, Tp]): () => R = { + if (currentThread.isInstanceOf[ForkJoinWorkerThread]) { + fjtask.fork + } else { + forkJoinPool.execute(fjtask) + } + + () => { + fjtask.join + fjtask.result + } + } + /** Executes a task on a fork/join pool and waits for it to finish. * * $fjdispatch diff --git a/src/library/scala/collection/parallel/mutable/ParArray.scala b/src/library/scala/collection/parallel/mutable/ParArray.scala index aa0eb03365..1e1d3452d3 100644 --- a/src/library/scala/collection/parallel/mutable/ParArray.scala +++ b/src/library/scala/collection/parallel/mutable/ParArray.scala @@ -159,6 +159,8 @@ extends ParSeq[T] sum } + override def fold[U >: T](z: U)(op: (U, U) => U): U = foldLeft[U](z)(op) + def aggregate[S](z: S)(seqop: (S, T) => S, combop: (S, S) => S): S = foldLeft[S](z)(seqop) override def sum[U >: T](implicit num: Numeric[U]): U = { @@ -552,13 +554,13 @@ extends ParSeq[T] targetarr(0) = z // do a parallel prefix scan - executeAndWait(new PartialScan[U, Any](z, op, 1, size, targetarr, parallelIterator) mapResult { st => - //println("-----------------------") - //println(targetarr.toList) - //st.printTree - executeAndWaitResult(new ApplyScanTree[U, Any](None, op, st, targetarr)) + executeAndWait(new BuildScanTree[U, Any](z, op, 1, size, targetarr, parallelIterator) mapResult { st => + // println("-----------------------") + // println(targetarr.toList) + // st.printTree + executeAndWaitResult(new ScanWithScanTree[U, Any](Some(z), op, st, array, targetarr)) }) - //println(targetarr.toList) + // println(targetarr.toList) // wrap the array into a parallel array (new ParArray[U](targarrseq)).asInstanceOf[That] diff --git a/test/benchmarks/src/scala/collection/parallel/benchmarks/parallel_array/ScanLight.scala b/test/benchmarks/src/scala/collection/parallel/benchmarks/parallel_array/ScanLight.scala index 02a52ac207..85bfabfea7 100644 --- a/test/benchmarks/src/scala/collection/parallel/benchmarks/parallel_array/ScanLight.scala +++ b/test/benchmarks/src/scala/collection/parallel/benchmarks/parallel_array/ScanLight.scala @@ -16,7 +16,7 @@ object ScanLight extends Companion { } def operation(a: Cont, b: Cont) = { val m = if (a.in < 0) 1 else 0 - new Cont(a.in + b.in + m * (0 until 10).reduceLeft(_ + _)) + new Cont(a.in + b.in + m * (0 until 2).reduceLeft(_ + _)) } } -- cgit v1.2.3