diff options
-rw-r--r-- | core/src/main/java/org/apache/spark/util/collection/TimSort.java (renamed from core/src/main/java/org/apache/spark/util/collection/Sorter.java) | 77 | ||||
-rw-r--r-- | core/src/main/scala/org/apache/spark/util/Utils.scala | 26 | ||||
-rw-r--r-- | core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala | 41 | ||||
-rw-r--r-- | core/src/main/scala/org/apache/spark/util/collection/Sorter.scala | 39 | ||||
-rw-r--r-- | core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala | 8 | ||||
-rw-r--r-- | core/src/test/scala/org/apache/spark/util/UtilsSuite.scala | 11 | ||||
-rw-r--r-- | core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala | 210 | ||||
-rw-r--r-- | project/MimaExcludes.scala | 4 |
8 files changed, 310 insertions, 106 deletions
diff --git a/core/src/main/java/org/apache/spark/util/collection/Sorter.java b/core/src/main/java/org/apache/spark/util/collection/TimSort.java index 64ad18c0e4..409e1a41c5 100644 --- a/core/src/main/java/org/apache/spark/util/collection/Sorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/TimSort.java @@ -20,18 +20,25 @@ package org.apache.spark.util.collection; import java.util.Comparator; /** - * A port of the Android Timsort class, which utilizes a "stable, adaptive, iterative mergesort." + * A port of the Android TimSort class, which utilizes a "stable, adaptive, iterative mergesort." * See the method comment on sort() for more details. * * This has been kept in Java with the original style in order to match very closely with the - * Anroid source code, and thus be easy to verify correctness. + * Android source code, and thus be easy to verify correctness. The class is package private. We put + * a simple Scala wrapper {@link org.apache.spark.util.collection.Sorter}, which is available to + * package org.apache.spark. * * The purpose of the port is to generalize the interface to the sort to accept input data formats * besides simple arrays where every element is sorted individually. For instance, the AppendOnlyMap * uses this to sort an Array with alternating elements of the form [key, value, key, value]. * This generalization comes with minimal overhead -- see SortDataFormat for more information. + * + * We allow key reuse to prevent creating many key objects -- see SortDataFormat. + * + * @see org.apache.spark.util.collection.SortDataFormat + * @see org.apache.spark.util.collection.Sorter */ -class Sorter<K, Buffer> { +class TimSort<K, Buffer> { /** * This is the minimum sized sequence that will be merged. Shorter @@ -54,7 +61,7 @@ class Sorter<K, Buffer> { private final SortDataFormat<K, Buffer> s; - public Sorter(SortDataFormat<K, Buffer> sortDataFormat) { + public TimSort(SortDataFormat<K, Buffer> sortDataFormat) { this.s = sortDataFormat; } @@ -91,7 +98,7 @@ class Sorter<K, Buffer> { * * @author Josh Bloch */ - void sort(Buffer a, int lo, int hi, Comparator<? super K> c) { + public void sort(Buffer a, int lo, int hi, Comparator<? super K> c) { assert c != null; int nRemaining = hi - lo; @@ -162,10 +169,13 @@ class Sorter<K, Buffer> { if (start == lo) start++; + K key0 = s.newKey(); + K key1 = s.newKey(); + Buffer pivotStore = s.allocate(1); for ( ; start < hi; start++) { s.copyElement(a, start, pivotStore, 0); - K pivot = s.getKey(pivotStore, 0); + K pivot = s.getKey(pivotStore, 0, key0); // Set left (and right) to the index where a[start] (pivot) belongs int left = lo; @@ -178,7 +188,7 @@ class Sorter<K, Buffer> { */ while (left < right) { int mid = (left + right) >>> 1; - if (c.compare(pivot, s.getKey(a, mid)) < 0) + if (c.compare(pivot, s.getKey(a, mid, key1)) < 0) right = mid; else left = mid + 1; @@ -235,13 +245,16 @@ class Sorter<K, Buffer> { if (runHi == hi) return 1; + K key0 = s.newKey(); + K key1 = s.newKey(); + // Find end of run, and reverse range if descending - if (c.compare(s.getKey(a, runHi++), s.getKey(a, lo)) < 0) { // Descending - while (runHi < hi && c.compare(s.getKey(a, runHi), s.getKey(a, runHi - 1)) < 0) + if (c.compare(s.getKey(a, runHi++, key0), s.getKey(a, lo, key1)) < 0) { // Descending + while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) < 0) runHi++; reverseRange(a, lo, runHi); } else { // Ascending - while (runHi < hi && c.compare(s.getKey(a, runHi), s.getKey(a, runHi - 1)) >= 0) + while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) >= 0) runHi++; } @@ -468,11 +481,13 @@ class Sorter<K, Buffer> { } stackSize--; + K key0 = s.newKey(); + /* * Find where the first element of run2 goes in run1. Prior elements * in run1 can be ignored (because they're already in place). */ - int k = gallopRight(s.getKey(a, base2), a, base1, len1, 0, c); + int k = gallopRight(s.getKey(a, base2, key0), a, base1, len1, 0, c); assert k >= 0; base1 += k; len1 -= k; @@ -483,7 +498,7 @@ class Sorter<K, Buffer> { * Find where the last element of run1 goes in run2. Subsequent elements * in run2 can be ignored (because they're already in place). */ - len2 = gallopLeft(s.getKey(a, base1 + len1 - 1), a, base2, len2, len2 - 1, c); + len2 = gallopLeft(s.getKey(a, base1 + len1 - 1, key0), a, base2, len2, len2 - 1, c); assert len2 >= 0; if (len2 == 0) return; @@ -517,10 +532,12 @@ class Sorter<K, Buffer> { assert len > 0 && hint >= 0 && hint < len; int lastOfs = 0; int ofs = 1; - if (c.compare(key, s.getKey(a, base + hint)) > 0) { + K key0 = s.newKey(); + + if (c.compare(key, s.getKey(a, base + hint, key0)) > 0) { // Gallop right until a[base+hint+lastOfs] < key <= a[base+hint+ofs] int maxOfs = len - hint; - while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs)) > 0) { + while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key0)) > 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; if (ofs <= 0) // int overflow @@ -535,7 +552,7 @@ class Sorter<K, Buffer> { } else { // key <= a[base + hint] // Gallop left until a[base+hint-ofs] < key <= a[base+hint-lastOfs] final int maxOfs = hint + 1; - while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs)) <= 0) { + while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key0)) <= 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; if (ofs <= 0) // int overflow @@ -560,7 +577,7 @@ class Sorter<K, Buffer> { while (lastOfs < ofs) { int m = lastOfs + ((ofs - lastOfs) >>> 1); - if (c.compare(key, s.getKey(a, base + m)) > 0) + if (c.compare(key, s.getKey(a, base + m, key0)) > 0) lastOfs = m + 1; // a[base + m] < key else ofs = m; // key <= a[base + m] @@ -587,10 +604,12 @@ class Sorter<K, Buffer> { int ofs = 1; int lastOfs = 0; - if (c.compare(key, s.getKey(a, base + hint)) < 0) { + K key1 = s.newKey(); + + if (c.compare(key, s.getKey(a, base + hint, key1)) < 0) { // Gallop left until a[b+hint - ofs] <= key < a[b+hint - lastOfs] int maxOfs = hint + 1; - while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs)) < 0) { + while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key1)) < 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; if (ofs <= 0) // int overflow @@ -606,7 +625,7 @@ class Sorter<K, Buffer> { } else { // a[b + hint] <= key // Gallop right until a[b+hint + lastOfs] <= key < a[b+hint + ofs] int maxOfs = len - hint; - while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs)) >= 0) { + while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key1)) >= 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; if (ofs <= 0) // int overflow @@ -630,7 +649,7 @@ class Sorter<K, Buffer> { while (lastOfs < ofs) { int m = lastOfs + ((ofs - lastOfs) >>> 1); - if (c.compare(key, s.getKey(a, base + m)) < 0) + if (c.compare(key, s.getKey(a, base + m, key1)) < 0) ofs = m; // key < a[b + m] else lastOfs = m + 1; // a[b + m] <= key @@ -679,6 +698,9 @@ class Sorter<K, Buffer> { return; } + K key0 = s.newKey(); + K key1 = s.newKey(); + Comparator<? super K> c = this.c; // Use local variable for performance int minGallop = this.minGallop; // " " " " " outer: @@ -692,7 +714,7 @@ class Sorter<K, Buffer> { */ do { assert len1 > 1 && len2 > 0; - if (c.compare(s.getKey(a, cursor2), s.getKey(tmp, cursor1)) < 0) { + if (c.compare(s.getKey(a, cursor2, key0), s.getKey(tmp, cursor1, key1)) < 0) { s.copyElement(a, cursor2++, a, dest++); count2++; count1 = 0; @@ -714,7 +736,7 @@ class Sorter<K, Buffer> { */ do { assert len1 > 1 && len2 > 0; - count1 = gallopRight(s.getKey(a, cursor2), tmp, cursor1, len1, 0, c); + count1 = gallopRight(s.getKey(a, cursor2, key0), tmp, cursor1, len1, 0, c); if (count1 != 0) { s.copyRange(tmp, cursor1, a, dest, count1); dest += count1; @@ -727,7 +749,7 @@ class Sorter<K, Buffer> { if (--len2 == 0) break outer; - count2 = gallopLeft(s.getKey(tmp, cursor1), a, cursor2, len2, 0, c); + count2 = gallopLeft(s.getKey(tmp, cursor1, key0), a, cursor2, len2, 0, c); if (count2 != 0) { s.copyRange(a, cursor2, a, dest, count2); dest += count2; @@ -784,6 +806,9 @@ class Sorter<K, Buffer> { int cursor2 = len2 - 1; // Indexes into tmp array int dest = base2 + len2 - 1; // Indexes into a + K key0 = s.newKey(); + K key1 = s.newKey(); + // Move last element of first run and deal with degenerate cases s.copyElement(a, cursor1--, a, dest--); if (--len1 == 0) { @@ -811,7 +836,7 @@ class Sorter<K, Buffer> { */ do { assert len1 > 0 && len2 > 1; - if (c.compare(s.getKey(tmp, cursor2), s.getKey(a, cursor1)) < 0) { + if (c.compare(s.getKey(tmp, cursor2, key0), s.getKey(a, cursor1, key1)) < 0) { s.copyElement(a, cursor1--, a, dest--); count1++; count2 = 0; @@ -833,7 +858,7 @@ class Sorter<K, Buffer> { */ do { assert len1 > 0 && len2 > 1; - count1 = len1 - gallopRight(s.getKey(tmp, cursor2), a, base1, len1, len1 - 1, c); + count1 = len1 - gallopRight(s.getKey(tmp, cursor2, key0), a, base1, len1, len1 - 1, c); if (count1 != 0) { dest -= count1; cursor1 -= count1; @@ -846,7 +871,7 @@ class Sorter<K, Buffer> { if (--len2 == 1) break outer; - count2 = len2 - gallopLeft(s.getKey(a, cursor1), tmp, 0, len2, len2 - 1, c); + count2 = len2 - gallopLeft(s.getKey(a, cursor1, key0), tmp, 0, len2, len2 - 1, c); if (count2 != 0) { dest -= count2; cursor2 -= count2; diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 612eca308b..1e881da511 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1272,12 +1272,28 @@ private[spark] object Utils extends Logging { /** * Timing method based on iterations that permit JVM JIT optimization. * @param numIters number of iterations - * @param f function to be executed + * @param f function to be executed. If prepare is not None, the running time of each call to f + * must be an order of magnitude longer than one millisecond for accurate timing. + * @param prepare function to be executed before each call to f. Its running time doesn't count. + * @return the total time across all iterations (not couting preparation time) */ - def timeIt(numIters: Int)(f: => Unit): Long = { - val start = System.currentTimeMillis - times(numIters)(f) - System.currentTimeMillis - start + def timeIt(numIters: Int)(f: => Unit, prepare: Option[() => Unit] = None): Long = { + if (prepare.isEmpty) { + val start = System.currentTimeMillis + times(numIters)(f) + System.currentTimeMillis - start + } else { + var i = 0 + var sum = 0L + while (i < numIters) { + prepare.get.apply() + val start = System.currentTimeMillis + f + sum += System.currentTimeMillis - start + i += 1 + } + sum + } } /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala b/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala index ac1528969f..4f0bf8384a 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala @@ -27,33 +27,51 @@ import scala.reflect.ClassTag * Example format: an array of numbers, where each element is also the key. * See [[KVArraySortDataFormat]] for a more exciting format. * - * This trait extends Any to ensure it is universal (and thus compiled to a Java interface). + * Note: Declaring and instantiating multiple subclasses of this class would prevent JIT inlining + * overridden methods and hence decrease the shuffle performance. * * @tparam K Type of the sort key of each element * @tparam Buffer Internal data structure used by a particular format (e.g., Array[Int]). */ // TODO: Making Buffer a real trait would be a better abstraction, but adds some complexity. -private[spark] trait SortDataFormat[K, Buffer] extends Any { +private[spark] +abstract class SortDataFormat[K, Buffer] { + + /** + * Creates a new mutable key for reuse. This should be implemented if you want to override + * [[getKey(Buffer, Int, K)]]. + */ + def newKey(): K = null.asInstanceOf[K] + /** Return the sort key for the element at the given index. */ protected def getKey(data: Buffer, pos: Int): K + /** + * Returns the sort key for the element at the given index and reuse the input key if possible. + * The default implementation ignores the reuse parameter and invokes [[getKey(Buffer, Int]]. + * If you want to override this method, you must implement [[newKey()]]. + */ + def getKey(data: Buffer, pos: Int, reuse: K): K = { + getKey(data, pos) + } + /** Swap two elements. */ - protected def swap(data: Buffer, pos0: Int, pos1: Int): Unit + def swap(data: Buffer, pos0: Int, pos1: Int): Unit /** Copy a single element from src(srcPos) to dst(dstPos). */ - protected def copyElement(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int): Unit + def copyElement(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int): Unit /** * Copy a range of elements starting at src(srcPos) to dst, starting at dstPos. * Overlapping ranges are allowed. */ - protected def copyRange(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int, length: Int): Unit + def copyRange(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int, length: Int): Unit /** * Allocates a Buffer that can hold up to 'length' elements. * All elements of the buffer should be considered invalid until data is explicitly copied in. */ - protected def allocate(length: Int): Buffer + def allocate(length: Int): Buffer } /** @@ -67,9 +85,9 @@ private[spark] trait SortDataFormat[K, Buffer] extends Any { private[spark] class KVArraySortDataFormat[K, T <: AnyRef : ClassTag] extends SortDataFormat[K, Array[T]] { - override protected def getKey(data: Array[T], pos: Int): K = data(2 * pos).asInstanceOf[K] + override def getKey(data: Array[T], pos: Int): K = data(2 * pos).asInstanceOf[K] - override protected def swap(data: Array[T], pos0: Int, pos1: Int) { + override def swap(data: Array[T], pos0: Int, pos1: Int) { val tmpKey = data(2 * pos0) val tmpVal = data(2 * pos0 + 1) data(2 * pos0) = data(2 * pos1) @@ -78,17 +96,16 @@ class KVArraySortDataFormat[K, T <: AnyRef : ClassTag] extends SortDataFormat[K, data(2 * pos1 + 1) = tmpVal } - override protected def copyElement(src: Array[T], srcPos: Int, dst: Array[T], dstPos: Int) { + override def copyElement(src: Array[T], srcPos: Int, dst: Array[T], dstPos: Int) { dst(2 * dstPos) = src(2 * srcPos) dst(2 * dstPos + 1) = src(2 * srcPos + 1) } - override protected def copyRange(src: Array[T], srcPos: Int, - dst: Array[T], dstPos: Int, length: Int) { + override def copyRange(src: Array[T], srcPos: Int, dst: Array[T], dstPos: Int, length: Int) { System.arraycopy(src, 2 * srcPos, dst, 2 * dstPos, 2 * length) } - override protected def allocate(length: Int): Array[T] = { + override def allocate(length: Int): Array[T] = { new Array[T](2 * length) } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/Sorter.scala b/core/src/main/scala/org/apache/spark/util/collection/Sorter.scala new file mode 100644 index 0000000000..39f66b8c42 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/Sorter.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.util.Comparator + +/** + * A simple wrapper over the Java implementation [[TimSort]]. + * + * The Java implementation is package private, and hence it cannot be called outside package + * org.apache.spark.util.collection. This is a simple wrapper of it that is available to spark. + */ +private[spark] +class Sorter[K, Buffer](private val s: SortDataFormat[K, Buffer]) { + + private val timSort = new TimSort(s) + + /** + * Sorts the input buffer within range [lo, hi). + */ + def sort(a: Buffer, lo: Int, hi: Int, c: Comparator[_ >: K]): Unit = { + timSort.sort(a, lo, hi, c) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala index 55b5713706..467b890fb4 100644 --- a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala +++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala @@ -96,13 +96,9 @@ private[spark] object XORShiftRandom { xorRand.nextInt() } - val iters = timeIt(numIters)(_) - /* Return results as a map instead of just printing to screen in case the user wants to do something with them */ - Map("javaTime" -> iters {javaRand.nextInt()}, - "xorTime" -> iters {xorRand.nextInt()}) - + Map("javaTime" -> timeIt(numIters) { javaRand.nextInt() }, + "xorTime" -> timeIt(numIters) { xorRand.nextInt() }) } - } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 65579bb9af..1c112334cc 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -351,4 +351,15 @@ class UtilsSuite extends FunSuite { outFile.delete() } } + + test("timeIt with prepare") { + var cnt = 0 + val prepare = () => { + cnt += 1 + Thread.sleep(1000) + } + val time = Utils.timeIt(2)({}, Some(prepare)) + require(cnt === 2, "prepare should be called twice") + require(time < 500, "preparation time should not count") + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala index 6fe1079c27..066d47c46a 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.util.collection -import java.lang.{Float => JFloat} +import java.lang.{Float => JFloat, Integer => JInteger} import java.util.{Arrays, Comparator} import org.scalatest.FunSuite @@ -30,11 +30,15 @@ class SorterSuite extends FunSuite { val rand = new XORShiftRandom(123) val data0 = Array.tabulate[Int](10000) { i => rand.nextInt() } val data1 = data0.clone() + val data2 = data0.clone() Arrays.sort(data0) new Sorter(new IntArraySortDataFormat).sort(data1, 0, data1.length, Ordering.Int) + new Sorter(new KeyReuseIntArraySortDataFormat) + .sort(data2, 0, data2.length, Ordering[IntWrapper]) - data0.zip(data1).foreach { case (x, y) => assert(x === y) } + assert(data0.view === data1.view) + assert(data0.view === data2.view) } test("KVArraySorter") { @@ -61,10 +65,33 @@ class SorterSuite extends FunSuite { } } + /** Runs an experiment several times. */ + def runExperiment(name: String, skip: Boolean = false)(f: => Unit, prepare: () => Unit): Unit = { + if (skip) { + println(s"Skipped experiment $name.") + return + } + + val firstTry = org.apache.spark.util.Utils.timeIt(1)(f, Some(prepare)) + System.gc() + + var i = 0 + var next10: Long = 0 + while (i < 10) { + val time = org.apache.spark.util.Utils.timeIt(1)(f, Some(prepare)) + next10 += time + println(s"$name: Took $time ms") + i += 1 + } + + println(s"$name: ($firstTry ms first try, ${next10 / 10} ms average)") + } + /** * This provides a simple benchmark for comparing the Sorter with Java internal sorting. * Ideally these would be executed one at a time, each in their own JVM, so their listing - * here is mainly to have the code. + * here is mainly to have the code. Running multiple tests within the same JVM session would + * prevent JIT inlining overridden methods and hence hurt the performance. * * The goal of this code is to sort an array of key-value pairs, where the array physically * has the keys and values alternating. The basic Java sorts work only on the keys, so the @@ -72,96 +99,167 @@ class SorterSuite extends FunSuite { * those, while the Sorter approach can work directly on the input data format. * * Note that the Java implementation varies tremendously between Java 6 and Java 7, when - * the Java sort changed from merge sort to Timsort. + * the Java sort changed from merge sort to TimSort. */ - ignore("Sorter benchmark") { - - /** Runs an experiment several times. */ - def runExperiment(name: String)(f: => Unit): Unit = { - val firstTry = org.apache.spark.util.Utils.timeIt(1)(f) - System.gc() - - var i = 0 - var next10: Long = 0 - while (i < 10) { - val time = org.apache.spark.util.Utils.timeIt(1)(f) - next10 += time - println(s"$name: Took $time ms") - i += 1 - } - - println(s"$name: ($firstTry ms first try, ${next10 / 10} ms average)") - } - + ignore("Sorter benchmark for key-value pairs") { val numElements = 25000000 // 25 mil val rand = new XORShiftRandom(123) - val keys = Array.tabulate[JFloat](numElements) { i => - new JFloat(rand.nextFloat()) + // Test our key-value pairs where each element is a Tuple2[Float, Integer]. + + val kvTuples = Array.tabulate(numElements) { i => + (new JFloat(rand.nextFloat()), new JInteger(i)) } - // Test our key-value pairs where each element is a Tuple2[Float, Integer) - val kvTupleArray = Array.tabulate[AnyRef](numElements) { i => - (keys(i / 2): Float, i / 2: Int) + val kvTupleArray = new Array[AnyRef](numElements) + val prepareKvTupleArray = () => { + System.arraycopy(kvTuples, 0, kvTupleArray, 0, numElements) } - runExperiment("Tuple-sort using Arrays.sort()") { + runExperiment("Tuple-sort using Arrays.sort()")({ Arrays.sort(kvTupleArray, new Comparator[AnyRef] { override def compare(x: AnyRef, y: AnyRef): Int = - Ordering.Float.compare(x.asInstanceOf[(Float, _)]._1, y.asInstanceOf[(Float, _)]._1) + x.asInstanceOf[(JFloat, _)]._1.compareTo(y.asInstanceOf[(JFloat, _)]._1) }) - } + }, prepareKvTupleArray) // Test our Sorter where each element alternates between Float and Integer, non-primitive - val keyValueArray = Array.tabulate[AnyRef](numElements * 2) { i => - if (i % 2 == 0) keys(i / 2) else new Integer(i / 2) + + val keyValues = { + val data = new Array[AnyRef](numElements * 2) + var i = 0 + while (i < numElements) { + data(2 * i) = kvTuples(i)._1 + data(2 * i + 1) = kvTuples(i)._2 + i += 1 + } + data } + + val keyValueArray = new Array[AnyRef](numElements * 2) + val prepareKeyValueArray = () => { + System.arraycopy(keyValues, 0, keyValueArray, 0, numElements * 2) + } + val sorter = new Sorter(new KVArraySortDataFormat[JFloat, AnyRef]) - runExperiment("KV-sort using Sorter") { - sorter.sort(keyValueArray, 0, keys.length, new Comparator[JFloat] { - override def compare(x: JFloat, y: JFloat): Int = Ordering.Float.compare(x, y) + runExperiment("KV-sort using Sorter")({ + sorter.sort(keyValueArray, 0, numElements, new Comparator[JFloat] { + override def compare(x: JFloat, y: JFloat): Int = x.compareTo(y) }) + }, prepareKeyValueArray) + } + + /** + * Tests for sorting with primitive keys with/without key reuse. Java's Arrays.sort is used as + * reference, which is expected to be faster but it can only sort a single array. Sorter can be + * used to sort parallel arrays. + * + * Ideally these would be executed one at a time, each in their own JVM, so their listing + * here is mainly to have the code. Running multiple tests within the same JVM session would + * prevent JIT inlining overridden methods and hence hurt the performance. + */ + test("Sorter benchmark for primitive int array") { + val numElements = 25000000 // 25 mil + val rand = new XORShiftRandom(123) + + val ints = Array.fill(numElements)(rand.nextInt()) + val intObjects = { + val data = new Array[JInteger](numElements) + var i = 0 + while (i < numElements) { + data(i) = new JInteger(ints(i)) + i += 1 + } + data } - // Test non-primitive sort on float array - runExperiment("Java Arrays.sort()") { - Arrays.sort(keys, new Comparator[JFloat] { - override def compare(x: JFloat, y: JFloat): Int = Ordering.Float.compare(x, y) - }) + val intObjectArray = new Array[JInteger](numElements) + val prepareIntObjectArray = () => { + System.arraycopy(intObjects, 0, intObjectArray, 0, numElements) } - // Test primitive sort on float array - val primitiveKeys = Array.tabulate[Float](numElements) { i => rand.nextFloat() } - runExperiment("Java Arrays.sort() on primitive keys") { - Arrays.sort(primitiveKeys) + runExperiment("Java Arrays.sort() on non-primitive int array")({ + Arrays.sort(intObjectArray, new Comparator[JInteger] { + override def compare(x: JInteger, y: JInteger): Int = x.compareTo(y) + }) + }, prepareIntObjectArray) + + val intPrimitiveArray = new Array[Int](numElements) + val prepareIntPrimitiveArray = () => { + System.arraycopy(ints, 0, intPrimitiveArray, 0, numElements) } - } -} + runExperiment("Java Arrays.sort() on primitive int array")({ + Arrays.sort(intPrimitiveArray) + }, prepareIntPrimitiveArray) -/** Format to sort a simple Array[Int]. Could be easily generified and specialized. */ -class IntArraySortDataFormat extends SortDataFormat[Int, Array[Int]] { - override protected def getKey(data: Array[Int], pos: Int): Int = { - data(pos) + val sorterWithoutKeyReuse = new Sorter(new IntArraySortDataFormat) + runExperiment("Sorter without key reuse on primitive int array")({ + sorterWithoutKeyReuse.sort(intPrimitiveArray, 0, numElements, Ordering[Int]) + }, prepareIntPrimitiveArray) + + val sorterWithKeyReuse = new Sorter(new KeyReuseIntArraySortDataFormat) + runExperiment("Sorter with key reuse on primitive int array")({ + sorterWithKeyReuse.sort(intPrimitiveArray, 0, numElements, Ordering[IntWrapper]) + }, prepareIntPrimitiveArray) } +} - override protected def swap(data: Array[Int], pos0: Int, pos1: Int): Unit = { +abstract class AbstractIntArraySortDataFormat[K] extends SortDataFormat[K, Array[Int]] { + + override def swap(data: Array[Int], pos0: Int, pos1: Int): Unit = { val tmp = data(pos0) data(pos0) = data(pos1) data(pos1) = tmp } - override protected def copyElement(src: Array[Int], srcPos: Int, dst: Array[Int], dstPos: Int) { + override def copyElement(src: Array[Int], srcPos: Int, dst: Array[Int], dstPos: Int) { dst(dstPos) = src(srcPos) } /** Copy a range of elements starting at src(srcPos) to dest, starting at destPos. */ - override protected def copyRange(src: Array[Int], srcPos: Int, - dst: Array[Int], dstPos: Int, length: Int) { + override def copyRange(src: Array[Int], srcPos: Int, dst: Array[Int], dstPos: Int, length: Int) { System.arraycopy(src, srcPos, dst, dstPos, length) } /** Allocates a new structure that can hold up to 'length' elements. */ - override protected def allocate(length: Int): Array[Int] = { + override def allocate(length: Int): Array[Int] = { new Array[Int](length) } } + +/** Format to sort a simple Array[Int]. Could be easily generified and specialized. */ +class IntArraySortDataFormat extends AbstractIntArraySortDataFormat[Int] { + + override protected def getKey(data: Array[Int], pos: Int): Int = { + data(pos) + } +} + +/** Wrapper of Int for key reuse. */ +class IntWrapper(var key: Int = 0) extends Ordered[IntWrapper] { + + override def compare(that: IntWrapper): Int = { + Ordering.Int.compare(key, that.key) + } +} + +/** SortDataFormat for Array[Int] with reused keys. */ +class KeyReuseIntArraySortDataFormat extends AbstractIntArraySortDataFormat[IntWrapper] { + + override def newKey(): IntWrapper = { + new IntWrapper() + } + + override def getKey(data: Array[Int], pos: Int, reuse: IntWrapper): IntWrapper = { + if (reuse == null) { + new IntWrapper(data(pos)) + } else { + reuse.key = data(pos) + reuse + } + } + + override protected def getKey(data: Array[Int], pos: Int): IntWrapper = { + getKey(data, pos, null) + } +} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index c58666af84..95152b58e2 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -53,7 +53,9 @@ object MimaExcludes { "org.apache.spark.scheduler.MapStatus"), // TaskContext was promoted to Abstract class ProblemFilters.exclude[AbstractClassProblem]( - "org.apache.spark.TaskContext") + "org.apache.spark.TaskContext"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.util.collection.SortDataFormat") ) ++ Seq( // Adding new methods to the JavaRDDLike trait: ProblemFilters.exclude[MissingMethodProblem]( |