aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/java/org/apache/spark/util/collection/TimSort.java (renamed from core/src/main/java/org/apache/spark/util/collection/Sorter.java)77
-rw-r--r--core/src/main/scala/org/apache/spark/util/Utils.scala26
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala41
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/Sorter.scala39
-rw-r--r--core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala8
-rw-r--r--core/src/test/scala/org/apache/spark/util/UtilsSuite.scala11
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala210
-rw-r--r--project/MimaExcludes.scala4
8 files changed, 310 insertions, 106 deletions
diff --git a/core/src/main/java/org/apache/spark/util/collection/Sorter.java b/core/src/main/java/org/apache/spark/util/collection/TimSort.java
index 64ad18c0e4..409e1a41c5 100644
--- a/core/src/main/java/org/apache/spark/util/collection/Sorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/TimSort.java
@@ -20,18 +20,25 @@ package org.apache.spark.util.collection;
import java.util.Comparator;
/**
- * A port of the Android Timsort class, which utilizes a "stable, adaptive, iterative mergesort."
+ * A port of the Android TimSort class, which utilizes a "stable, adaptive, iterative mergesort."
* See the method comment on sort() for more details.
*
* This has been kept in Java with the original style in order to match very closely with the
- * Anroid source code, and thus be easy to verify correctness.
+ * Android source code, and thus be easy to verify correctness. The class is package private. We put
+ * a simple Scala wrapper {@link org.apache.spark.util.collection.Sorter}, which is available to
+ * package org.apache.spark.
*
* The purpose of the port is to generalize the interface to the sort to accept input data formats
* besides simple arrays where every element is sorted individually. For instance, the AppendOnlyMap
* uses this to sort an Array with alternating elements of the form [key, value, key, value].
* This generalization comes with minimal overhead -- see SortDataFormat for more information.
+ *
+ * We allow key reuse to prevent creating many key objects -- see SortDataFormat.
+ *
+ * @see org.apache.spark.util.collection.SortDataFormat
+ * @see org.apache.spark.util.collection.Sorter
*/
-class Sorter<K, Buffer> {
+class TimSort<K, Buffer> {
/**
* This is the minimum sized sequence that will be merged. Shorter
@@ -54,7 +61,7 @@ class Sorter<K, Buffer> {
private final SortDataFormat<K, Buffer> s;
- public Sorter(SortDataFormat<K, Buffer> sortDataFormat) {
+ public TimSort(SortDataFormat<K, Buffer> sortDataFormat) {
this.s = sortDataFormat;
}
@@ -91,7 +98,7 @@ class Sorter<K, Buffer> {
*
* @author Josh Bloch
*/
- void sort(Buffer a, int lo, int hi, Comparator<? super K> c) {
+ public void sort(Buffer a, int lo, int hi, Comparator<? super K> c) {
assert c != null;
int nRemaining = hi - lo;
@@ -162,10 +169,13 @@ class Sorter<K, Buffer> {
if (start == lo)
start++;
+ K key0 = s.newKey();
+ K key1 = s.newKey();
+
Buffer pivotStore = s.allocate(1);
for ( ; start < hi; start++) {
s.copyElement(a, start, pivotStore, 0);
- K pivot = s.getKey(pivotStore, 0);
+ K pivot = s.getKey(pivotStore, 0, key0);
// Set left (and right) to the index where a[start] (pivot) belongs
int left = lo;
@@ -178,7 +188,7 @@ class Sorter<K, Buffer> {
*/
while (left < right) {
int mid = (left + right) >>> 1;
- if (c.compare(pivot, s.getKey(a, mid)) < 0)
+ if (c.compare(pivot, s.getKey(a, mid, key1)) < 0)
right = mid;
else
left = mid + 1;
@@ -235,13 +245,16 @@ class Sorter<K, Buffer> {
if (runHi == hi)
return 1;
+ K key0 = s.newKey();
+ K key1 = s.newKey();
+
// Find end of run, and reverse range if descending
- if (c.compare(s.getKey(a, runHi++), s.getKey(a, lo)) < 0) { // Descending
- while (runHi < hi && c.compare(s.getKey(a, runHi), s.getKey(a, runHi - 1)) < 0)
+ if (c.compare(s.getKey(a, runHi++, key0), s.getKey(a, lo, key1)) < 0) { // Descending
+ while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) < 0)
runHi++;
reverseRange(a, lo, runHi);
} else { // Ascending
- while (runHi < hi && c.compare(s.getKey(a, runHi), s.getKey(a, runHi - 1)) >= 0)
+ while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) >= 0)
runHi++;
}
@@ -468,11 +481,13 @@ class Sorter<K, Buffer> {
}
stackSize--;
+ K key0 = s.newKey();
+
/*
* Find where the first element of run2 goes in run1. Prior elements
* in run1 can be ignored (because they're already in place).
*/
- int k = gallopRight(s.getKey(a, base2), a, base1, len1, 0, c);
+ int k = gallopRight(s.getKey(a, base2, key0), a, base1, len1, 0, c);
assert k >= 0;
base1 += k;
len1 -= k;
@@ -483,7 +498,7 @@ class Sorter<K, Buffer> {
* Find where the last element of run1 goes in run2. Subsequent elements
* in run2 can be ignored (because they're already in place).
*/
- len2 = gallopLeft(s.getKey(a, base1 + len1 - 1), a, base2, len2, len2 - 1, c);
+ len2 = gallopLeft(s.getKey(a, base1 + len1 - 1, key0), a, base2, len2, len2 - 1, c);
assert len2 >= 0;
if (len2 == 0)
return;
@@ -517,10 +532,12 @@ class Sorter<K, Buffer> {
assert len > 0 && hint >= 0 && hint < len;
int lastOfs = 0;
int ofs = 1;
- if (c.compare(key, s.getKey(a, base + hint)) > 0) {
+ K key0 = s.newKey();
+
+ if (c.compare(key, s.getKey(a, base + hint, key0)) > 0) {
// Gallop right until a[base+hint+lastOfs] < key <= a[base+hint+ofs]
int maxOfs = len - hint;
- while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs)) > 0) {
+ while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key0)) > 0) {
lastOfs = ofs;
ofs = (ofs << 1) + 1;
if (ofs <= 0) // int overflow
@@ -535,7 +552,7 @@ class Sorter<K, Buffer> {
} else { // key <= a[base + hint]
// Gallop left until a[base+hint-ofs] < key <= a[base+hint-lastOfs]
final int maxOfs = hint + 1;
- while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs)) <= 0) {
+ while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key0)) <= 0) {
lastOfs = ofs;
ofs = (ofs << 1) + 1;
if (ofs <= 0) // int overflow
@@ -560,7 +577,7 @@ class Sorter<K, Buffer> {
while (lastOfs < ofs) {
int m = lastOfs + ((ofs - lastOfs) >>> 1);
- if (c.compare(key, s.getKey(a, base + m)) > 0)
+ if (c.compare(key, s.getKey(a, base + m, key0)) > 0)
lastOfs = m + 1; // a[base + m] < key
else
ofs = m; // key <= a[base + m]
@@ -587,10 +604,12 @@ class Sorter<K, Buffer> {
int ofs = 1;
int lastOfs = 0;
- if (c.compare(key, s.getKey(a, base + hint)) < 0) {
+ K key1 = s.newKey();
+
+ if (c.compare(key, s.getKey(a, base + hint, key1)) < 0) {
// Gallop left until a[b+hint - ofs] <= key < a[b+hint - lastOfs]
int maxOfs = hint + 1;
- while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs)) < 0) {
+ while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key1)) < 0) {
lastOfs = ofs;
ofs = (ofs << 1) + 1;
if (ofs <= 0) // int overflow
@@ -606,7 +625,7 @@ class Sorter<K, Buffer> {
} else { // a[b + hint] <= key
// Gallop right until a[b+hint + lastOfs] <= key < a[b+hint + ofs]
int maxOfs = len - hint;
- while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs)) >= 0) {
+ while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key1)) >= 0) {
lastOfs = ofs;
ofs = (ofs << 1) + 1;
if (ofs <= 0) // int overflow
@@ -630,7 +649,7 @@ class Sorter<K, Buffer> {
while (lastOfs < ofs) {
int m = lastOfs + ((ofs - lastOfs) >>> 1);
- if (c.compare(key, s.getKey(a, base + m)) < 0)
+ if (c.compare(key, s.getKey(a, base + m, key1)) < 0)
ofs = m; // key < a[b + m]
else
lastOfs = m + 1; // a[b + m] <= key
@@ -679,6 +698,9 @@ class Sorter<K, Buffer> {
return;
}
+ K key0 = s.newKey();
+ K key1 = s.newKey();
+
Comparator<? super K> c = this.c; // Use local variable for performance
int minGallop = this.minGallop; // " " " " "
outer:
@@ -692,7 +714,7 @@ class Sorter<K, Buffer> {
*/
do {
assert len1 > 1 && len2 > 0;
- if (c.compare(s.getKey(a, cursor2), s.getKey(tmp, cursor1)) < 0) {
+ if (c.compare(s.getKey(a, cursor2, key0), s.getKey(tmp, cursor1, key1)) < 0) {
s.copyElement(a, cursor2++, a, dest++);
count2++;
count1 = 0;
@@ -714,7 +736,7 @@ class Sorter<K, Buffer> {
*/
do {
assert len1 > 1 && len2 > 0;
- count1 = gallopRight(s.getKey(a, cursor2), tmp, cursor1, len1, 0, c);
+ count1 = gallopRight(s.getKey(a, cursor2, key0), tmp, cursor1, len1, 0, c);
if (count1 != 0) {
s.copyRange(tmp, cursor1, a, dest, count1);
dest += count1;
@@ -727,7 +749,7 @@ class Sorter<K, Buffer> {
if (--len2 == 0)
break outer;
- count2 = gallopLeft(s.getKey(tmp, cursor1), a, cursor2, len2, 0, c);
+ count2 = gallopLeft(s.getKey(tmp, cursor1, key0), a, cursor2, len2, 0, c);
if (count2 != 0) {
s.copyRange(a, cursor2, a, dest, count2);
dest += count2;
@@ -784,6 +806,9 @@ class Sorter<K, Buffer> {
int cursor2 = len2 - 1; // Indexes into tmp array
int dest = base2 + len2 - 1; // Indexes into a
+ K key0 = s.newKey();
+ K key1 = s.newKey();
+
// Move last element of first run and deal with degenerate cases
s.copyElement(a, cursor1--, a, dest--);
if (--len1 == 0) {
@@ -811,7 +836,7 @@ class Sorter<K, Buffer> {
*/
do {
assert len1 > 0 && len2 > 1;
- if (c.compare(s.getKey(tmp, cursor2), s.getKey(a, cursor1)) < 0) {
+ if (c.compare(s.getKey(tmp, cursor2, key0), s.getKey(a, cursor1, key1)) < 0) {
s.copyElement(a, cursor1--, a, dest--);
count1++;
count2 = 0;
@@ -833,7 +858,7 @@ class Sorter<K, Buffer> {
*/
do {
assert len1 > 0 && len2 > 1;
- count1 = len1 - gallopRight(s.getKey(tmp, cursor2), a, base1, len1, len1 - 1, c);
+ count1 = len1 - gallopRight(s.getKey(tmp, cursor2, key0), a, base1, len1, len1 - 1, c);
if (count1 != 0) {
dest -= count1;
cursor1 -= count1;
@@ -846,7 +871,7 @@ class Sorter<K, Buffer> {
if (--len2 == 1)
break outer;
- count2 = len2 - gallopLeft(s.getKey(a, cursor1), tmp, 0, len2, len2 - 1, c);
+ count2 = len2 - gallopLeft(s.getKey(a, cursor1, key0), tmp, 0, len2, len2 - 1, c);
if (count2 != 0) {
dest -= count2;
cursor2 -= count2;
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 612eca308b..1e881da511 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1272,12 +1272,28 @@ private[spark] object Utils extends Logging {
/**
* Timing method based on iterations that permit JVM JIT optimization.
* @param numIters number of iterations
- * @param f function to be executed
+ * @param f function to be executed. If prepare is not None, the running time of each call to f
+ * must be an order of magnitude longer than one millisecond for accurate timing.
+ * @param prepare function to be executed before each call to f. Its running time doesn't count.
+ * @return the total time across all iterations (not counting preparation time)
*/
- def timeIt(numIters: Int)(f: => Unit): Long = {
- val start = System.currentTimeMillis
- times(numIters)(f)
- System.currentTimeMillis - start
+ def timeIt(numIters: Int)(f: => Unit, prepare: Option[() => Unit] = None): Long = {
+ if (prepare.isEmpty) {
+ val start = System.currentTimeMillis
+ times(numIters)(f)
+ System.currentTimeMillis - start
+ } else {
+ var i = 0
+ var sum = 0L
+ while (i < numIters) {
+ prepare.get.apply()
+ val start = System.currentTimeMillis
+ f
+ sum += System.currentTimeMillis - start
+ i += 1
+ }
+ sum
+ }
}
/**
diff --git a/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala b/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala
index ac1528969f..4f0bf8384a 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala
@@ -27,33 +27,51 @@ import scala.reflect.ClassTag
* Example format: an array of numbers, where each element is also the key.
* See [[KVArraySortDataFormat]] for a more exciting format.
*
- * This trait extends Any to ensure it is universal (and thus compiled to a Java interface).
+ * Note: Declaring and instantiating multiple subclasses of this class would prevent JIT inlining
+ * overridden methods and hence decrease the shuffle performance.
*
* @tparam K Type of the sort key of each element
* @tparam Buffer Internal data structure used by a particular format (e.g., Array[Int]).
*/
// TODO: Making Buffer a real trait would be a better abstraction, but adds some complexity.
-private[spark] trait SortDataFormat[K, Buffer] extends Any {
+private[spark]
+abstract class SortDataFormat[K, Buffer] {
+
+ /**
+ * Creates a new mutable key for reuse. This should be implemented if you want to override
+ * [[getKey(Buffer, Int, K)]].
+ */
+ def newKey(): K = null.asInstanceOf[K]
+
/** Return the sort key for the element at the given index. */
protected def getKey(data: Buffer, pos: Int): K
+ /**
+ * Returns the sort key for the element at the given index and reuse the input key if possible.
+ * The default implementation ignores the reuse parameter and invokes [[getKey(Buffer, Int)]].
+ * If you want to override this method, you must implement [[newKey()]].
+ */
+ def getKey(data: Buffer, pos: Int, reuse: K): K = {
+ getKey(data, pos)
+ }
+
/** Swap two elements. */
- protected def swap(data: Buffer, pos0: Int, pos1: Int): Unit
+ def swap(data: Buffer, pos0: Int, pos1: Int): Unit
/** Copy a single element from src(srcPos) to dst(dstPos). */
- protected def copyElement(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int): Unit
+ def copyElement(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int): Unit
/**
* Copy a range of elements starting at src(srcPos) to dst, starting at dstPos.
* Overlapping ranges are allowed.
*/
- protected def copyRange(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int, length: Int): Unit
+ def copyRange(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int, length: Int): Unit
/**
* Allocates a Buffer that can hold up to 'length' elements.
* All elements of the buffer should be considered invalid until data is explicitly copied in.
*/
- protected def allocate(length: Int): Buffer
+ def allocate(length: Int): Buffer
}
/**
@@ -67,9 +85,9 @@ private[spark] trait SortDataFormat[K, Buffer] extends Any {
private[spark]
class KVArraySortDataFormat[K, T <: AnyRef : ClassTag] extends SortDataFormat[K, Array[T]] {
- override protected def getKey(data: Array[T], pos: Int): K = data(2 * pos).asInstanceOf[K]
+ override def getKey(data: Array[T], pos: Int): K = data(2 * pos).asInstanceOf[K]
- override protected def swap(data: Array[T], pos0: Int, pos1: Int) {
+ override def swap(data: Array[T], pos0: Int, pos1: Int) {
val tmpKey = data(2 * pos0)
val tmpVal = data(2 * pos0 + 1)
data(2 * pos0) = data(2 * pos1)
@@ -78,17 +96,16 @@ class KVArraySortDataFormat[K, T <: AnyRef : ClassTag] extends SortDataFormat[K,
data(2 * pos1 + 1) = tmpVal
}
- override protected def copyElement(src: Array[T], srcPos: Int, dst: Array[T], dstPos: Int) {
+ override def copyElement(src: Array[T], srcPos: Int, dst: Array[T], dstPos: Int) {
dst(2 * dstPos) = src(2 * srcPos)
dst(2 * dstPos + 1) = src(2 * srcPos + 1)
}
- override protected def copyRange(src: Array[T], srcPos: Int,
- dst: Array[T], dstPos: Int, length: Int) {
+ override def copyRange(src: Array[T], srcPos: Int, dst: Array[T], dstPos: Int, length: Int) {
System.arraycopy(src, 2 * srcPos, dst, 2 * dstPos, 2 * length)
}
- override protected def allocate(length: Int): Array[T] = {
+ override def allocate(length: Int): Array[T] = {
new Array[T](2 * length)
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/Sorter.scala b/core/src/main/scala/org/apache/spark/util/collection/Sorter.scala
new file mode 100644
index 0000000000..39f66b8c42
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/Sorter.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.util.Comparator
+
+/**
+ * A simple wrapper over the Java implementation [[TimSort]].
+ *
+ * The Java implementation is package private, and hence it cannot be called outside package
+ * org.apache.spark.util.collection. This is a simple wrapper of it that is available to spark.
+ */
+private[spark]
+class Sorter[K, Buffer](private val s: SortDataFormat[K, Buffer]) {
+
+ private val timSort = new TimSort(s)
+
+ /**
+ * Sorts the input buffer within range [lo, hi).
+ */
+ def sort(a: Buffer, lo: Int, hi: Int, c: Comparator[_ >: K]): Unit = {
+ timSort.sort(a, lo, hi, c)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
index 55b5713706..467b890fb4 100644
--- a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
@@ -96,13 +96,9 @@ private[spark] object XORShiftRandom {
xorRand.nextInt()
}
- val iters = timeIt(numIters)(_)
-
/* Return results as a map instead of just printing to screen
in case the user wants to do something with them */
- Map("javaTime" -> iters {javaRand.nextInt()},
- "xorTime" -> iters {xorRand.nextInt()})
-
+ Map("javaTime" -> timeIt(numIters) { javaRand.nextInt() },
+ "xorTime" -> timeIt(numIters) { xorRand.nextInt() })
}
-
}
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index 65579bb9af..1c112334cc 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -351,4 +351,15 @@ class UtilsSuite extends FunSuite {
outFile.delete()
}
}
+
+ test("timeIt with prepare") {
+ var cnt = 0
+ val prepare = () => {
+ cnt += 1
+ Thread.sleep(1000)
+ }
+ val time = Utils.timeIt(2)({}, Some(prepare))
+ require(cnt === 2, "prepare should be called twice")
+ require(time < 500, "preparation time should not count")
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala
index 6fe1079c27..066d47c46a 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.util.collection
-import java.lang.{Float => JFloat}
+import java.lang.{Float => JFloat, Integer => JInteger}
import java.util.{Arrays, Comparator}
import org.scalatest.FunSuite
@@ -30,11 +30,15 @@ class SorterSuite extends FunSuite {
val rand = new XORShiftRandom(123)
val data0 = Array.tabulate[Int](10000) { i => rand.nextInt() }
val data1 = data0.clone()
+ val data2 = data0.clone()
Arrays.sort(data0)
new Sorter(new IntArraySortDataFormat).sort(data1, 0, data1.length, Ordering.Int)
+ new Sorter(new KeyReuseIntArraySortDataFormat)
+ .sort(data2, 0, data2.length, Ordering[IntWrapper])
- data0.zip(data1).foreach { case (x, y) => assert(x === y) }
+ assert(data0.view === data1.view)
+ assert(data0.view === data2.view)
}
test("KVArraySorter") {
@@ -61,10 +65,33 @@ class SorterSuite extends FunSuite {
}
}
+ /** Runs an experiment several times. */
+ def runExperiment(name: String, skip: Boolean = false)(f: => Unit, prepare: () => Unit): Unit = {
+ if (skip) {
+ println(s"Skipped experiment $name.")
+ return
+ }
+
+ val firstTry = org.apache.spark.util.Utils.timeIt(1)(f, Some(prepare))
+ System.gc()
+
+ var i = 0
+ var next10: Long = 0
+ while (i < 10) {
+ val time = org.apache.spark.util.Utils.timeIt(1)(f, Some(prepare))
+ next10 += time
+ println(s"$name: Took $time ms")
+ i += 1
+ }
+
+ println(s"$name: ($firstTry ms first try, ${next10 / 10} ms average)")
+ }
+
/**
* This provides a simple benchmark for comparing the Sorter with Java internal sorting.
* Ideally these would be executed one at a time, each in their own JVM, so their listing
- * here is mainly to have the code.
+ * here is mainly to have the code. Running multiple tests within the same JVM session would
+ * prevent JIT inlining overridden methods and hence hurt the performance.
*
* The goal of this code is to sort an array of key-value pairs, where the array physically
* has the keys and values alternating. The basic Java sorts work only on the keys, so the
@@ -72,96 +99,167 @@ class SorterSuite extends FunSuite {
* those, while the Sorter approach can work directly on the input data format.
*
* Note that the Java implementation varies tremendously between Java 6 and Java 7, when
- * the Java sort changed from merge sort to Timsort.
+ * the Java sort changed from merge sort to TimSort.
*/
- ignore("Sorter benchmark") {
-
- /** Runs an experiment several times. */
- def runExperiment(name: String)(f: => Unit): Unit = {
- val firstTry = org.apache.spark.util.Utils.timeIt(1)(f)
- System.gc()
-
- var i = 0
- var next10: Long = 0
- while (i < 10) {
- val time = org.apache.spark.util.Utils.timeIt(1)(f)
- next10 += time
- println(s"$name: Took $time ms")
- i += 1
- }
-
- println(s"$name: ($firstTry ms first try, ${next10 / 10} ms average)")
- }
-
+ ignore("Sorter benchmark for key-value pairs") {
val numElements = 25000000 // 25 mil
val rand = new XORShiftRandom(123)
- val keys = Array.tabulate[JFloat](numElements) { i =>
- new JFloat(rand.nextFloat())
+ // Test our key-value pairs where each element is a Tuple2[Float, Integer].
+
+ val kvTuples = Array.tabulate(numElements) { i =>
+ (new JFloat(rand.nextFloat()), new JInteger(i))
}
- // Test our key-value pairs where each element is a Tuple2[Float, Integer)
- val kvTupleArray = Array.tabulate[AnyRef](numElements) { i =>
- (keys(i / 2): Float, i / 2: Int)
+ val kvTupleArray = new Array[AnyRef](numElements)
+ val prepareKvTupleArray = () => {
+ System.arraycopy(kvTuples, 0, kvTupleArray, 0, numElements)
}
- runExperiment("Tuple-sort using Arrays.sort()") {
+ runExperiment("Tuple-sort using Arrays.sort()")({
Arrays.sort(kvTupleArray, new Comparator[AnyRef] {
override def compare(x: AnyRef, y: AnyRef): Int =
- Ordering.Float.compare(x.asInstanceOf[(Float, _)]._1, y.asInstanceOf[(Float, _)]._1)
+ x.asInstanceOf[(JFloat, _)]._1.compareTo(y.asInstanceOf[(JFloat, _)]._1)
})
- }
+ }, prepareKvTupleArray)
// Test our Sorter where each element alternates between Float and Integer, non-primitive
- val keyValueArray = Array.tabulate[AnyRef](numElements * 2) { i =>
- if (i % 2 == 0) keys(i / 2) else new Integer(i / 2)
+
+ val keyValues = {
+ val data = new Array[AnyRef](numElements * 2)
+ var i = 0
+ while (i < numElements) {
+ data(2 * i) = kvTuples(i)._1
+ data(2 * i + 1) = kvTuples(i)._2
+ i += 1
+ }
+ data
}
+
+ val keyValueArray = new Array[AnyRef](numElements * 2)
+ val prepareKeyValueArray = () => {
+ System.arraycopy(keyValues, 0, keyValueArray, 0, numElements * 2)
+ }
+
val sorter = new Sorter(new KVArraySortDataFormat[JFloat, AnyRef])
- runExperiment("KV-sort using Sorter") {
- sorter.sort(keyValueArray, 0, keys.length, new Comparator[JFloat] {
- override def compare(x: JFloat, y: JFloat): Int = Ordering.Float.compare(x, y)
+ runExperiment("KV-sort using Sorter")({
+ sorter.sort(keyValueArray, 0, numElements, new Comparator[JFloat] {
+ override def compare(x: JFloat, y: JFloat): Int = x.compareTo(y)
})
+ }, prepareKeyValueArray)
+ }
+
+ /**
+ * Tests for sorting with primitive keys with/without key reuse. Java's Arrays.sort is used as
+ * reference, which is expected to be faster but it can only sort a single array. Sorter can be
+ * used to sort parallel arrays.
+ *
+ * Ideally these would be executed one at a time, each in their own JVM, so their listing
+ * here is mainly to have the code. Running multiple tests within the same JVM session would
+ * prevent JIT inlining overridden methods and hence hurt the performance.
+ */
+ ignore("Sorter benchmark for primitive int array") {
+ val numElements = 25000000 // 25 mil
+ val rand = new XORShiftRandom(123)
+
+ val ints = Array.fill(numElements)(rand.nextInt())
+ val intObjects = {
+ val data = new Array[JInteger](numElements)
+ var i = 0
+ while (i < numElements) {
+ data(i) = new JInteger(ints(i))
+ i += 1
+ }
+ data
}
- // Test non-primitive sort on float array
- runExperiment("Java Arrays.sort()") {
- Arrays.sort(keys, new Comparator[JFloat] {
- override def compare(x: JFloat, y: JFloat): Int = Ordering.Float.compare(x, y)
- })
+ val intObjectArray = new Array[JInteger](numElements)
+ val prepareIntObjectArray = () => {
+ System.arraycopy(intObjects, 0, intObjectArray, 0, numElements)
}
- // Test primitive sort on float array
- val primitiveKeys = Array.tabulate[Float](numElements) { i => rand.nextFloat() }
- runExperiment("Java Arrays.sort() on primitive keys") {
- Arrays.sort(primitiveKeys)
+ runExperiment("Java Arrays.sort() on non-primitive int array")({
+ Arrays.sort(intObjectArray, new Comparator[JInteger] {
+ override def compare(x: JInteger, y: JInteger): Int = x.compareTo(y)
+ })
+ }, prepareIntObjectArray)
+
+ val intPrimitiveArray = new Array[Int](numElements)
+ val prepareIntPrimitiveArray = () => {
+ System.arraycopy(ints, 0, intPrimitiveArray, 0, numElements)
}
- }
-}
+ runExperiment("Java Arrays.sort() on primitive int array")({
+ Arrays.sort(intPrimitiveArray)
+ }, prepareIntPrimitiveArray)
-/** Format to sort a simple Array[Int]. Could be easily generified and specialized. */
-class IntArraySortDataFormat extends SortDataFormat[Int, Array[Int]] {
- override protected def getKey(data: Array[Int], pos: Int): Int = {
- data(pos)
+ val sorterWithoutKeyReuse = new Sorter(new IntArraySortDataFormat)
+ runExperiment("Sorter without key reuse on primitive int array")({
+ sorterWithoutKeyReuse.sort(intPrimitiveArray, 0, numElements, Ordering[Int])
+ }, prepareIntPrimitiveArray)
+
+ val sorterWithKeyReuse = new Sorter(new KeyReuseIntArraySortDataFormat)
+ runExperiment("Sorter with key reuse on primitive int array")({
+ sorterWithKeyReuse.sort(intPrimitiveArray, 0, numElements, Ordering[IntWrapper])
+ }, prepareIntPrimitiveArray)
}
+}
- override protected def swap(data: Array[Int], pos0: Int, pos1: Int): Unit = {
+abstract class AbstractIntArraySortDataFormat[K] extends SortDataFormat[K, Array[Int]] {
+
+ override def swap(data: Array[Int], pos0: Int, pos1: Int): Unit = {
val tmp = data(pos0)
data(pos0) = data(pos1)
data(pos1) = tmp
}
- override protected def copyElement(src: Array[Int], srcPos: Int, dst: Array[Int], dstPos: Int) {
+ override def copyElement(src: Array[Int], srcPos: Int, dst: Array[Int], dstPos: Int) {
dst(dstPos) = src(srcPos)
}
/** Copy a range of elements starting at src(srcPos) to dest, starting at destPos. */
- override protected def copyRange(src: Array[Int], srcPos: Int,
- dst: Array[Int], dstPos: Int, length: Int) {
+ override def copyRange(src: Array[Int], srcPos: Int, dst: Array[Int], dstPos: Int, length: Int) {
System.arraycopy(src, srcPos, dst, dstPos, length)
}
/** Allocates a new structure that can hold up to 'length' elements. */
- override protected def allocate(length: Int): Array[Int] = {
+ override def allocate(length: Int): Array[Int] = {
new Array[Int](length)
}
}
+
+/** Format to sort a simple Array[Int]. Could be easily generified and specialized. */
+class IntArraySortDataFormat extends AbstractIntArraySortDataFormat[Int] {
+
+ override protected def getKey(data: Array[Int], pos: Int): Int = {
+ data(pos)
+ }
+}
+
+/** Wrapper of Int for key reuse. */
+class IntWrapper(var key: Int = 0) extends Ordered[IntWrapper] {
+
+ override def compare(that: IntWrapper): Int = {
+ Ordering.Int.compare(key, that.key)
+ }
+}
+
+/** SortDataFormat for Array[Int] with reused keys. */
+class KeyReuseIntArraySortDataFormat extends AbstractIntArraySortDataFormat[IntWrapper] {
+
+ override def newKey(): IntWrapper = {
+ new IntWrapper()
+ }
+
+ override def getKey(data: Array[Int], pos: Int, reuse: IntWrapper): IntWrapper = {
+ if (reuse == null) {
+ new IntWrapper(data(pos))
+ } else {
+ reuse.key = data(pos)
+ reuse
+ }
+ }
+
+ override protected def getKey(data: Array[Int], pos: Int): IntWrapper = {
+ getKey(data, pos, null)
+ }
+}
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index c58666af84..95152b58e2 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -53,7 +53,9 @@ object MimaExcludes {
"org.apache.spark.scheduler.MapStatus"),
// TaskContext was promoted to Abstract class
ProblemFilters.exclude[AbstractClassProblem](
- "org.apache.spark.TaskContext")
+ "org.apache.spark.TaskContext"),
+ ProblemFilters.exclude[IncompatibleTemplateDefProblem](
+ "org.apache.spark.util.collection.SortDataFormat")
) ++ Seq(
// Adding new methods to the JavaRDDLike trait:
ProblemFilters.exclude[MissingMethodProblem](