summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdriaan Moors <adriaan.moors@typesafe.com>2015-06-16 15:14:38 -0700
committerAdriaan Moors <adriaan.moors@typesafe.com>2015-06-16 15:14:38 -0700
commit40b28c462e15b3372faf78df1d6503a2b134d3cb (patch)
treed0545634e38b69c39050d9527ec5481539b02f83
parentd4deb270b6ea9e6e75a535573470d173b55da9b2 (diff)
parent7d1b1292db82f33905f9a9ca214cf22f0a16591f (diff)
downloadscala-40b28c462e15b3372faf78df1d6503a2b134d3cb.tar.gz
scala-40b28c462e15b3372faf78df1d6503a2b134d3cb.tar.bz2
scala-40b28c462e15b3372faf78df1d6503a2b134d3cb.zip
Merge pull request #4557 from adriaanm/mcpasterton-2.10
Clean implementation of sorts for scala.util.Sorting.
-rw-r--r--bincompat-forward.whitelist.conf52
-rw-r--r--src/library/scala/util/Sorting.scala712
-rw-r--r--test/junit/scala/util/SortingTest.scala69
3 files changed, 356 insertions, 477 deletions
diff --git a/bincompat-forward.whitelist.conf b/bincompat-forward.whitelist.conf
index 1c532889c2..b81929c9f8 100644
--- a/bincompat-forward.whitelist.conf
+++ b/bincompat-forward.whitelist.conf
@@ -195,6 +195,58 @@ filter {
{
matchName="scala.xml.pull.ExceptionEvent$"
problemName=MissingClassProblem
+ },
+ {
+ matchName="scala.util.Sorting.scala$util$Sorting$$mergeSort$default$5"
+ problemName=MissingMethodProblem
+ },
+ {
+ matchName="scala.util.Sorting.scala$util$Sorting$$mergeSort$mBc$sp"
+ problemName=MissingMethodProblem
+ },
+ {
+ matchName="scala.util.Sorting.scala$util$Sorting$$mergeSort$mFc$sp"
+ problemName=MissingMethodProblem
+ },
+ {
+ matchName="scala.util.Sorting.scala$util$Sorting$$mergeSort$mJc$sp"
+ problemName=MissingMethodProblem
+ },
+ {
+ matchName="scala.util.Sorting.scala$util$Sorting$$mergeSort"
+ problemName=MissingMethodProblem
+ },
+ {
+ matchName="scala.util.Sorting.scala$util$Sorting$$mergeSort$mCc$sp"
+ problemName=MissingMethodProblem
+ },
+ {
+ matchName="scala.util.Sorting.scala$util$Sorting$$mergeSort$mSc$sp"
+ problemName=MissingMethodProblem
+ },
+ {
+ matchName="scala.util.Sorting.scala$util$Sorting$$insertionSort"
+ problemName=MissingMethodProblem
+ },
+ {
+ matchName="scala.util.Sorting.scala$util$Sorting$$mergeSort$mZc$sp"
+ problemName=MissingMethodProblem
+ },
+ {
+ matchName="scala.util.Sorting.scala$util$Sorting$$mergeSort$mDc$sp"
+ problemName=MissingMethodProblem
+ },
+ {
+ matchName="scala.util.Sorting.scala$util$Sorting$$mergeSort$mIc$sp"
+ problemName=MissingMethodProblem
+ },
+ {
+ matchName="scala.util.Sorting.scala$util$Sorting$$mergeSorted"
+ problemName=MissingMethodProblem
+ },
+ {
+ matchName="scala.util.Sorting.scala$util$Sorting$$booleanSort"
+ problemName=MissingMethodProblem
}
]
}
diff --git a/src/library/scala/util/Sorting.scala b/src/library/scala/util/Sorting.scala
index 276e157f55..ee2bdbc4a7 100644
--- a/src/library/scala/util/Sorting.scala
+++ b/src/library/scala/util/Sorting.scala
@@ -1,6 +1,6 @@
/* __ *\
** ________ ___ / / ___ Scala API **
-** / __/ __// _ | / / / _ | (c) 2006-2009, Ross Judson **
+** / __/ __// _ | / / / _ | (c) 2006-2015, LAMP/EPFL **
** __\ \/ /__/ __ |/ /__/ __ | http://scala-lang.org/ **
** /____/\___/_/ |_/____/_/ | | **
** |/ **
@@ -9,518 +9,276 @@
package scala
package util
-import scala.reflect.{ ClassTag, classTag }
-import scala.math.{ Ordering, max, min }
+import scala.reflect.ClassTag
+import scala.math.Ordering
-/** The Sorting object provides functions that can sort various kinds of
- * objects. You can provide a comparison function, or you can request a sort
- * of items that are viewable as [[scala.math.Ordered]]. Some sorts that
- * operate directly on a subset of value types are also provided. These
- * implementations are derived from those in the Sun JDK.
+/** The `Sorting` object provides convenience wrappers for `java.util.Arrays.sort`.
+ * Methods that defer to `java.util.Arrays.sort` say that they do or under what
+ * conditions that they do.
*
- * Note that stability doesn't matter for value types, so use the `quickSort`
- * variants for those. `stableSort` is intended to be used with
- * objects when the prior ordering should be preserved, where possible.
+ * `Sorting` also implements a general-purpose quicksort and stable (merge) sort
+ * for those cases where `java.util.Arrays.sort` could only be used at the cost
+ * of a large memory penalty. If performance rather than memory usage is the
+ * primary concern, one may wish to find alternate strategies to use
+ * `java.util.Arrays.sort` directly e.g. by boxing primitives to use
+ * a custom ordering on them.
+ *
+ * `Sorting` provides methods where you can provide a comparison function, or
+ * can request a sort of items that are [[scala.math.Ordered]] or that
+ * otherwise have an implicit or explicit [[scala.math.Ordering]].
+ *
+ * Note also that high-performance non-default sorts for numeric types
+ * are not provided. If this is required, it is advisable to investigate
+ * other libraries that cover this use case.
*
* @author Ross Judson
- * @version 1.0
+ * @author Adriaan Moors
+ * @author Rex Kerr
+ * @version 1.1
*/
object Sorting {
- /** Quickly sort an array of Doubles. */
- def quickSort(a: Array[Double]) { sort1(a, 0, a.length) }
-
- /** Quickly sort an array of items with an implicit Ordering. */
- def quickSort[K: Ordering](a: Array[K]) { sort1(a, 0, a.length) }
-
- /** Quickly sort an array of Ints. */
- def quickSort(a: Array[Int]) { sort1(a, 0, a.length) }
-
- /** Quickly sort an array of Floats. */
- def quickSort(a: Array[Float]) { sort1(a, 0, a.length) }
-
- /** Sort an array of K where K is Ordered, preserving the existing order
- * where the values are equal. */
- def stableSort[K: ClassTag: Ordering](a: Array[K]) {
- stableSort(a, 0, a.length-1, new Array[K](a.length), Ordering[K].lt _)
- }
+ /** Sort an array of Doubles using `java.util.Arrays.sort`. */
+ def quickSort(a: Array[Double]): Unit = java.util.Arrays.sort(a)
- /** Sorts an array of `K` given an ordering function `f`.
- * `f` should return `true` iff its first parameter is strictly less than its second parameter.
- */
- def stableSort[K: ClassTag](a: Array[K], f: (K, K) => Boolean) {
- stableSort(a, 0, a.length-1, new Array[K](a.length), f)
- }
+ /** Sort an array of Ints using `java.util.Arrays.sort`. */
+ def quickSort(a: Array[Int]): Unit = java.util.Arrays.sort(a)
- /** Sorts an arbitrary sequence into an array, given a comparison function
- * that should return `true` iff parameter one is strictly less than parameter two.
- *
- * @param a the sequence to be sorted.
- * @param f the comparison function.
- * @return the sorted sequence of items.
- */
- def stableSort[K: ClassTag](a: Seq[K], f: (K, K) => Boolean): Array[K] = {
- val ret = a.toArray
- stableSort(ret, f)
- ret
- }
+ /** Sort an array of Floats using `java.util.Arrays.sort`. */
+ def quickSort(a: Array[Float]): Unit = java.util.Arrays.sort(a)
+
+ private final val qsortThreshold = 16
- /** Sorts an arbitrary sequence of items that are viewable as ordered. */
- def stableSort[K: ClassTag: Ordering](a: Seq[K]): Array[K] =
- stableSort(a, Ordering[K].lt _)
-
- /** Stably sorts a sequence of items given an extraction function that will
- * return an ordered key from an item.
- *
- * @param a the sequence to be sorted.
- * @param f the comparison function.
- * @return the sorted sequence of items.
- */
- def stableSort[K: ClassTag, M: Ordering](a: Seq[K], f: K => M): Array[K] =
- stableSort(a)(implicitly[ClassTag[K]], Ordering[M] on f)
-
- private def sort1[K: Ordering](x: Array[K], off: Int, len: Int) {
- val ord = Ordering[K]
- import ord._
-
- def swap(a: Int, b: Int) {
- val t = x(a)
- x(a) = x(b)
- x(b) = t
- }
- def vecswap(_a: Int, _b: Int, n: Int) {
- var a = _a
- var b = _b
- var i = 0
- while (i < n) {
- swap(a, b)
- i += 1
- a += 1
- b += 1
- }
- }
- def med3(a: Int, b: Int, c: Int) = {
- if (x(a) < x(b)) {
- if (x(b) < x(c)) b else if (x(a) < x(c)) c else a
- } else {
- if (x(b) > x(c)) b else if (x(a) > x(c)) c else a
- }
- }
- def sort2(off: Int, len: Int) {
- // Insertion sort on smallest arrays
- if (len < 7) {
- var i = off
- while (i < len + off) {
- var j = i
- while (j > off && x(j-1) > x(j)) {
- swap(j, j-1)
- j -= 1
+ /** Sort array `a` with quicksort, using the Ordering on its elements.
+ * This algorithm sorts in place, so no additional memory is used aside from
+ * what might be required to box individual elements during comparison.
+ */
+ def quickSort[K: Ordering](a: Array[K]): Unit = {
+ // Must have iN >= i0 or math will fail. Also, i0 >= 0.
+ def inner(a: Array[K], i0: Int, iN: Int, ord: Ordering[K]): Unit = {
+ if (iN - i0 < qsortThreshold) insertionSort(a, i0, iN, ord)
+ else {
+ var iK = (i0 + iN) >>> 1 // Unsigned div by 2
+ // Find index of median of first, central, and last elements
+ var pL =
+ if (ord.compare(a(i0), a(iN - 1)) <= 0)
+ if (ord.compare(a(i0), a(iK)) < 0)
+ if (ord.compare(a(iN - 1), a(iK)) < 0) iN - 1 else iK
+ else i0
+ else
+ if (ord.compare(a(i0), a(iK)) < 0) i0
+ else
+ if (ord.compare(a(iN - 1), a(iK)) <= 0) iN - 1
+ else iK
+ val pivot = a(pL)
+ // pL is the start of the pivot block; move it into the middle if needed
+ if (pL != iK) { a(pL) = a(iK); a(iK) = pivot; pL = iK }
+ // Elements equal to the pivot will be in range pL until pR
+ var pR = pL + 1
+ // Items known to be less than pivot are below iA (range i0 until iA)
+ var iA = i0
+ // Items known to be greater than pivot are at or above iB (range iB until iN)
+ var iB = iN
+ // Scan through everything in the buffer before the pivot(s)
+ while (pL - iA > 0) {
+ val current = a(iA)
+ ord.compare(current, pivot) match {
+ case 0 =>
+ // Swap current out with pivot block
+ a(iA) = a(pL - 1)
+ a(pL - 1) = current
+ pL -= 1
+ case x if x < 0 =>
+ // Already in place. Just update indicies.
+ iA += 1
+ case _ if iB > pR =>
+ // Wrong side. There's room on the other side, so swap
+ a(iA) = a(iB - 1)
+ a(iB - 1) = current
+ iB -= 1
+ case _ =>
+ // Wrong side and there is no room. Swap by rotating pivot block.
+ a(iA) = a(pL - 1)
+ a(pL - 1) = a(pR - 1)
+ a(pR - 1) = current
+ pL -= 1
+ pR -= 1
+ iB -= 1
}
- i += 1
}
- } else {
- // Choose a partition element, v
- var m = off + (len >> 1) // Small arrays, middle element
- if (len > 7) {
- var l = off
- var n = off + len - 1
- if (len > 40) { // Big arrays, pseudomedian of 9
- val s = len / 8
- l = med3(l, l+s, l+2*s)
- m = med3(m-s, m, m+s)
- n = med3(n-2*s, n-s, n)
+ // Get anything remaining in buffer after the pivot(s)
+ while (iB - pR > 0) {
+ val current = a(iB - 1)
+ ord.compare(current, pivot) match {
+ case 0 =>
+ // Swap current out with pivot block
+ a(iB - 1) = a(pR)
+ a(pR) = current
+ pR += 1
+ case x if x > 0 =>
+ // Already in place. Just update indices.
+ iB -= 1
+ case _ =>
+ // Wrong side and we already know there is no room. Swap by rotating pivot block.
+ a(iB - 1) = a(pR)
+ a(pR) = a(pL)
+ a(pL) = current
+ iA += 1
+ pL += 1
+ pR += 1
}
- m = med3(l, m, n) // Mid-size, med of 3
}
- val v = x(m)
-
- // Establish Invariant: v* (<v)* (>v)* v*
- var a = off
- var b = a
- var c = off + len - 1
- var d = c
- var done = false
- while (!done) {
- while (b <= c && x(b) <= v) {
- if (x(b) == v) {
- swap(a, b)
- a += 1
- }
- b += 1
- }
- while (c >= b && x(c) >= v) {
- if (x(c) == v) {
- swap(c, d)
- d -= 1
- }
- c -= 1
- }
- if (b > c) {
- done = true
- } else {
- swap(b, c)
- c -= 1
- b += 1
- }
+ // Use tail recursion on large half (Sedgewick's method) so we don't blow up the stack if pivots are poorly chosen
+ if (iA - i0 < iN - iB) {
+ inner(a, i0, iA, ord) // True recursion
+ inner(a, iB, iN, ord) // Should be tail recursion
+ }
+ else {
+ inner(a, iB, iN, ord) // True recursion
+ inner(a, i0, iA, ord) // Should be tail recursion
}
-
- // Swap partition elements back to middle
- val n = off + len
- var s = math.min(a-off, b-a)
- vecswap(off, b-s, s)
- s = math.min(d-c, n-d-1)
- vecswap(b, n-s, s)
-
- // Recursively sort non-partition-elements
- s = b - a
- if (s > 1)
- sort2(off, s)
- s = d - c
- if (s > 1)
- sort2(n-s, s)
}
}
- sort2(off, len)
+ inner(a, 0, a.length, implicitly[Ordering[K]])
}
-
- private def sort1(x: Array[Int], off: Int, len: Int) {
- def swap(a: Int, b: Int) {
- val t = x(a)
- x(a) = x(b)
- x(b) = t
+
+ private final val mergeThreshold = 32
+
+ // Ordering[T] might be slow especially for boxed primitives, so use binary search variant of insertion sort
+ // Caller must pass iN >= i0 or math will fail. Also, i0 >= 0.
+ private def insertionSort[@specialized T](a: Array[T], i0: Int, iN: Int, ord: Ordering[T]): Unit = {
+ val n = iN - i0
+ if (n < 2) return
+ if (ord.compare(a(i0), a(i0+1)) > 0) {
+ val temp = a(i0)
+ a(i0) = a(i0+1)
+ a(i0+1) = temp
}
- def vecswap(_a: Int, _b: Int, n: Int) {
- var a = _a
- var b = _b
- var i = 0
- while (i < n) {
- swap(a, b)
- i += 1
- a += 1
- b += 1
- }
- }
- def med3(a: Int, b: Int, c: Int) = {
- if (x(a) < x(b)) {
- if (x(b) < x(c)) b else if (x(a) < x(c)) c else a
- } else {
- if (x(b) > x(c)) b else if (x(a) > x(c)) c else a
- }
- }
- def sort2(off: Int, len: Int) {
- // Insertion sort on smallest arrays
- if (len < 7) {
- var i = off
- while (i < len + off) {
- var j = i
- while (j>off && x(j-1) > x(j)) {
- swap(j, j-1)
- j -= 1
- }
- i += 1
+ var m = 2
+ while (m < n) {
+ // Speed up already-sorted case by checking last element first
+ val next = a(i0 + m)
+ if (ord.compare(next, a(i0+m-1)) < 0) {
+ var iA = i0
+ var iB = i0 + m - 1
+ while (iB - iA > 1) {
+ val ix = (iA + iB) >>> 1 // Use bit shift to get unsigned div by 2
+ if (ord.compare(next, a(ix)) < 0) iB = ix
+ else iA = ix
}
- } else {
- // Choose a partition element, v
- var m = off + (len >> 1) // Small arrays, middle element
- if (len > 7) {
- var l = off
- var n = off + len - 1
- if (len > 40) { // Big arrays, pseudomedian of 9
- val s = len / 8
- l = med3(l, l+s, l+2*s)
- m = med3(m-s, m, m+s)
- n = med3(n-2*s, n-s, n)
- }
- m = med3(l, m, n) // Mid-size, med of 3
- }
- val v = x(m)
-
- // Establish Invariant: v* (<v)* (>v)* v*
- var a = off
- var b = a
- var c = off + len - 1
- var d = c
- var done = false
- while (!done) {
- while (b <= c && x(b) <= v) {
- if (x(b) == v) {
- swap(a, b)
- a += 1
- }
- b += 1
- }
- while (c >= b && x(c) >= v) {
- if (x(c) == v) {
- swap(c, d)
- d -= 1
- }
- c -= 1
- }
- if (b > c) {
- done = true
- } else {
- swap(b, c)
- c -= 1
- b += 1
- }
+ val ix = iA + (if (ord.compare(next, a(iA)) < 0) 0 else 1)
+ var i = i0 + m
+ while (i > ix) {
+ a(i) = a(i-1)
+ i -= 1
}
-
- // Swap partition elements back to middle
- val n = off + len
- var s = math.min(a-off, b-a)
- vecswap(off, b-s, s)
- s = math.min(d-c, n-d-1)
- vecswap(b, n-s, s)
-
- // Recursively sort non-partition-elements
- s = b - a
- if (s > 1)
- sort2(off, s)
- s = d - c
- if (s > 1)
- sort2(n-s, s)
+ a(ix) = next
}
+ m += 1
}
- sort2(off, len)
}
-
- private def sort1(x: Array[Double], off: Int, len: Int) {
- def swap(a: Int, b: Int) {
- val t = x(a)
- x(a) = x(b)
- x(b) = t
+
+ // Caller is required to pass iN >= i0, else math will fail. Also, i0 >= 0.
+ private def mergeSort[@specialized T: ClassTag](a: Array[T], i0: Int, iN: Int, ord: Ordering[T], scratch: Array[T] = null): Unit = {
+ if (iN - i0 < mergeThreshold) insertionSort(a, i0, iN, ord)
+ else {
+ val iK = (i0 + iN) >>> 1 // Bit shift equivalent to unsigned math, no overflow
+ val sc = if (scratch eq null) new Array[T](iK - i0) else scratch
+ mergeSort(a, i0, iK, ord, sc)
+ mergeSort(a, iK, iN, ord, sc)
+ mergeSorted(a, i0, iK, iN, ord, sc)
}
- def vecswap(_a: Int, _b: Int, n: Int) {
- var a = _a
- var b = _b
- var i = 0
- while (i < n) {
- swap(a, b)
+ }
+
+ // Must have 0 <= i0 < iK < iN
+ private def mergeSorted[@specialized T](a: Array[T], i0: Int, iK: Int, iN: Int, ord: Ordering[T], scratch: Array[T]): Unit = {
+ // Check to make sure we're not already in order
+ if (ord.compare(a(iK-1), a(iK)) > 0) {
+ var i = i0
+ val jN = iK - i0
+ var j = 0
+ while (i < iK) {
+ scratch (j) = a(i)
i += 1
- a += 1
- b += 1
- }
- }
- def med3(a: Int, b: Int, c: Int) = {
- val ab = x(a) compare x(b)
- val bc = x(b) compare x(c)
- val ac = x(a) compare x(c)
- if (ab < 0) {
- if (bc < 0) b else if (ac < 0) c else a
- } else {
- if (bc > 0) b else if (ac > 0) c else a
+ j += 1
}
- }
- def sort2(off: Int, len: Int) {
- // Insertion sort on smallest arrays
- if (len < 7) {
- var i = off
- while (i < len + off) {
- var j = i
- while (j > off && (x(j-1) compare x(j)) > 0) {
- swap(j, j-1)
- j -= 1
- }
- i += 1
- }
- } else {
- // Choose a partition element, v
- var m = off + (len >> 1) // Small arrays, middle element
- if (len > 7) {
- var l = off
- var n = off + len - 1
- if (len > 40) { // Big arrays, pseudomedian of 9
- val s = len / 8
- l = med3(l, l+s, l+2*s)
- m = med3(m-s, m, m+s)
- n = med3(n-2*s, n-s, n)
- }
- m = med3(l, m, n) // Mid-size, med of 3
- }
- val v = x(m)
-
- // Establish Invariant: v* (<v)* (>v)* v*
- var a = off
- var b = a
- var c = off + len - 1
- var d = c
- var done = false
- while (!done) {
- var bv = x(b) compare v
- while (b <= c && bv <= 0) {
- if (bv == 0) {
- swap(a, b)
- a += 1
- }
- b += 1
- if (b <= c) bv = x(b) compare v
- }
- var cv = x(c) compare v
- while (c >= b && cv >= 0) {
- if (cv == 0) {
- swap(c, d)
- d -= 1
- }
- c -= 1
- if (c >= b) cv = x(c) compare v
- }
- if (b > c) {
- done = true
- } else {
- swap(b, c)
- c -= 1
- b += 1
- }
- }
-
- // Swap partition elements back to middle
- val n = off + len
- var s = math.min(a-off, b-a)
- vecswap(off, b-s, s)
- s = math.min(d-c, n-d-1)
- vecswap(b, n-s, s)
-
- // Recursively sort non-partition-elements
- s = b - a
- if (s > 1)
- sort2(off, s)
- s = d - c
- if (s > 1)
- sort2(n-s, s)
+ var k = i0
+ j = 0
+ while (i < iN && j < jN) {
+ if (ord.compare(a(i), scratch(j)) < 0) { a(k) = a(i); i += 1 }
+ else { a(k) = scratch(j); j += 1 }
+ k += 1
}
+ while (j < jN) { a(k) = scratch(j); j += 1; k += 1 }
+ // Don't need to finish a(i) because it's already in place, k = i
}
- sort2(off, len)
}
-
- private def sort1(x: Array[Float], off: Int, len: Int) {
- def swap(a: Int, b: Int) {
- val t = x(a)
- x(a) = x(b)
- x(b) = t
+
+ // Why would you even do this?
+ private def booleanSort(a: Array[Boolean]): Unit = {
+ var i = 0
+ var n = 0
+ while (i < a.length) {
+ if (!a(i)) n += 1
+ i += 1
}
- def vecswap(_a: Int, _b: Int, n: Int) {
- var a = _a
- var b = _b
- var i = 0
- while (i < n) {
- swap(a, b)
- i += 1
- a += 1
- b += 1
- }
+ i = 0
+ while (i < n) {
+ a(i) = false
+ i += 1
}
- def med3(a: Int, b: Int, c: Int) = {
- val ab = x(a) compare x(b)
- val bc = x(b) compare x(c)
- val ac = x(a) compare x(c)
- if (ab < 0) {
- if (bc < 0) b else if (ac < 0) c else a
- } else {
- if (bc > 0) b else if (ac > 0) c else a
- }
+ while (i < a.length) {
+ a(i) = true
+ i += 1
}
- def sort2(off: Int, len: Int) {
- // Insertion sort on smallest arrays
- if (len < 7) {
- var i = off
- while (i < len + off) {
- var j = i
- while (j > off && (x(j-1) compare x(j)) > 0) {
- swap(j, j-1)
- j -= 1
- }
- i += 1
- }
- } else {
- // Choose a partition element, v
- var m = off + (len >> 1) // Small arrays, middle element
- if (len > 7) {
- var l = off
- var n = off + len - 1
- if (len > 40) { // Big arrays, pseudomedian of 9
- val s = len / 8
- l = med3(l, l+s, l+2*s)
- m = med3(m-s, m, m+s)
- n = med3(n-2*s, n-s, n)
- }
- m = med3(l, m, n) // Mid-size, med of 3
- }
- val v = x(m)
+ }
- // Establish Invariant: v* (<v)* (>v)* v*
- var a = off
- var b = a
- var c = off + len - 1
- var d = c
- var done = false
- while (!done) {
- var bv = x(b) compare v
- while (b <= c && bv <= 0) {
- if (bv == 0) {
- swap(a, b)
- a += 1
- }
- b += 1
- if (b <= c) bv = x(b) compare v
- }
- var cv = x(c) compare v
- while (c >= b && cv >= 0) {
- if (cv == 0) {
- swap(c, d)
- d -= 1
- }
- c -= 1
- if (c >= b) cv = x(c) compare v
- }
- if (b > c) {
- done = true
- } else {
- swap(b, c)
- c -= 1
- b += 1
- }
- }
+ // TODO: add upper bound: T <: AnyRef, propagate to callers below (not binary compatible)
+ // Maybe also rename all these methods to `sort`.
+ @inline private def sort[T](a: Array[T], ord: Ordering[T]): Unit = a match {
+ case _: Array[AnyRef] =>
+ // Note that runtime matches are covariant, so could actually be any Array[T] s.t. T is not primitive (even boxed value classes)
+ if (a.length > 1 && (ord eq null)) throw new NullPointerException("Ordering")
+ java.util.Arrays.sort(a, ord)
+ case a: Array[Int] => if (ord eq Ordering.Int) java.util.Arrays.sort(a) else mergeSort[Int](a, 0, a.length, ord)
+ case a: Array[Double] => mergeSort[Double](a, 0, a.length, ord) // Because not all NaNs are identical, stability is meaningful!
+ case a: Array[Long] => if (ord eq Ordering.Long) java.util.Arrays.sort(a) else mergeSort[Long](a, 0, a.length, ord)
+ case a: Array[Float] => mergeSort[Float](a, 0, a.length, ord) // Because not all NaNs are identical, stability is meaningful!
+ case a: Array[Char] => if (ord eq Ordering.Char) java.util.Arrays.sort(a) else mergeSort[Char](a, 0, a.length, ord)
+ case a: Array[Byte] => if (ord eq Ordering.Byte) java.util.Arrays.sort(a) else mergeSort[Byte](a, 0, a.length, ord)
+ case a: Array[Short] => if (ord eq Ordering.Short) java.util.Arrays.sort(a) else mergeSort[Short](a, 0, a.length, ord)
+ case a: Array[Boolean] => if (ord eq Ordering.Boolean) booleanSort(a) else mergeSort[Boolean](a, 0, a.length, ord)
+ // Array[Unit] is matched as an Array[AnyRef] due to covariance in runtime matching. Not worth catching it as a special case.
+ case null => throw new NullPointerException
+ }
- // Swap partition elements back to middle
- val n = off + len
- var s = math.min(a-off, b-a)
- vecswap(off, b-s, s)
- s = math.min(d-c, n-d-1)
- vecswap(b, n-s, s)
+ // TODO: remove unnecessary ClassTag (not binary compatible)
+ /** Sort array `a` using the Ordering on its elements, preserving the original ordering where possible. Uses `java.util.Arrays.sort` unless `K` is a primitive type. */
+ def stableSort[K: ClassTag: Ordering](a: Array[K]): Unit = sort(a, Ordering[K])
- // Recursively sort non-partition-elements
- s = b - a
- if (s > 1)
- sort2(off, s)
- s = d - c
- if (s > 1)
- sort2(n-s, s)
- }
- }
- sort2(off, len)
+ // TODO: Remove unnecessary ClassTag (not binary compatible)
+ // TODO: make this fast for primitive K (could be specialized if it didn't go through Ordering)
+ /** Sort array `a` using function `f` that computes the less-than relation for each element. Uses `java.util.Arrays.sort` unless `K` is a primitive type. */
+ def stableSort[K: ClassTag](a: Array[K], f: (K, K) => Boolean): Unit = sort(a, Ordering fromLessThan f)
+
+ /** A sorted Array, using the Ordering for the elements in the sequence `a`. Uses `java.util.Arrays.sort` unless `K` is a primitive type. */
+ def stableSort[K: ClassTag: Ordering](a: Seq[K]): Array[K] = {
+ val ret = a.toArray
+ sort(ret, Ordering[K])
+ ret
}
- private def stableSort[K : ClassTag](a: Array[K], lo: Int, hi: Int, scratch: Array[K], f: (K,K) => Boolean) {
- if (lo < hi) {
- val mid = (lo+hi) / 2
- stableSort(a, lo, mid, scratch, f)
- stableSort(a, mid+1, hi, scratch, f)
- var k, t_lo = lo
- var t_hi = mid + 1
- while (k <= hi) {
- if ((t_lo <= mid) && ((t_hi > hi) || (!f(a(t_hi), a(t_lo))))) {
- scratch(k) = a(t_lo)
- t_lo += 1
- } else {
- scratch(k) = a(t_hi)
- t_hi += 1
- }
- k += 1
- }
- k = lo
- while (k <= hi) {
- a(k) = scratch(k)
- k += 1
- }
- }
+ // TODO: make this fast for primitive K (could be specialized if it didn't go through Ordering)
+ /** A sorted Array, given a function `f` that computes the less-than relation for each item in the sequence `a`. Uses `java.util.Arrays.sort` unless `K` is a primitive type. */
+ def stableSort[K: ClassTag](a: Seq[K], f: (K, K) => Boolean): Array[K] = {
+ val ret = a.toArray
+ sort(ret, Ordering fromLessThan f)
+ ret
+ }
+
+ /** A sorted Array, given an extraction function `f` that returns an ordered key for each item in the sequence `a`. Uses `java.util.Arrays.sort` unless `K` is a primitive type. */
+ def stableSort[K: ClassTag, M: Ordering](a: Seq[K], f: K => M): Array[K] = {
+ val ret = a.toArray
+ sort(ret, Ordering[M] on f)
+ ret
}
}
diff --git a/test/junit/scala/util/SortingTest.scala b/test/junit/scala/util/SortingTest.scala
new file mode 100644
index 0000000000..15a00c8903
--- /dev/null
+++ b/test/junit/scala/util/SortingTest.scala
@@ -0,0 +1,69 @@
+package scala.util
+
+import org.junit.Test
+import org.junit.Assert._
+import scala.math.{ Ordered, Ordering }
+import scala.reflect.ClassTag
+
+class SortingTest {
+ case class N(i: Int, j: Int) extends Ordered[N] { def compare(n: N) = if (i < n.i) -1 else if (i > n.i) 1 else 0 }
+
+ def mkA(n: Int, max: Int) = Array.tabulate(n)(i => N(util.Random.nextInt(max), i))
+
+ def isStable(a: Array[N]): Boolean = { var i = 1; while (i < a.length) { if (a(i).i < a(i-1).i || (a(i).i == a(i-1).i && a(i).j < a(i-1).j)) return false; i += 1 }; true }
+
+ def isAntistable(a: Array[N]): Boolean =
+ { var i = 1; while (i < a.length) { if (a(i).i > a(i-1).i || (a(i).i == a(i-1).i && a(i).j < a(i-1).j)) return false; i += 1 }; true }
+
+ def isSorted(a: Array[N]): Boolean = { var i = 1; while (i < a.length) { if (a(i).i < a(i-1).i) return false; i += 1 }; true }
+
+ def isAntisorted(a: Array[N]): Boolean = { var i = 1; while (i < a.length) { if (a(i).i > a(i-1).i) return false; i += 1 }; true }
+
+ val sizes = Seq.range(0, 65) ++ Seq(256, 1024, 9121, 65539)
+ val variety = Seq(1, 2, 10, 100, 1000, Int.MaxValue)
+ val workLimit = 1e6
+ val rng = new util.Random(198571)
+
+ val backwardsN = Ordering by ((n: N) => -n.i)
+
+ def runOneTest(size: Int, variety: Int): Unit = {
+ val xs = Array.tabulate(size)(i => N(rng.nextInt(variety), i))
+ val ys = Array.range(0, xs.length)
+ val zs = { val temp = xs.clone; java.util.Arrays.sort(temp, new java.util.Comparator[N] { def compare(a: N, b: N) = a.compare(b) }); temp }
+ val qxs = { val temp = xs.clone; Sorting.quickSort(temp); temp }
+ val pxs = { val temp = xs.clone; Sorting.quickSort(temp)(backwardsN); temp }
+ val sxs = { val temp = xs.clone; Sorting.stableSort(temp); temp }
+ val rxs = { val temp = xs.clone; Sorting.stableSort(temp)(implicitly[ClassTag[N]], backwardsN); temp }
+ val sys = Sorting.stableSort(ys.clone: Seq[Int], (i: Int) => xs(i))
+
+ assertTrue("Quicksort should be in order", isSorted(qxs))
+ assertTrue("Quicksort should be in reverse order", isAntisorted(pxs))
+ assertTrue("Stable sort should be sorted and stable", isStable(sxs))
+ assertTrue("Stable sort should be reverse sorted but stable", isAntistable(rxs))
+ assertTrue("Stable sorting by proxy should produce sorted stable list", isStable(sys.map(i => xs(i))))
+ assertTrue("Quicksort should produce canonical ordering", (qxs zip zs).forall{ case (a,b) => a.i == b.i })
+ assertTrue("Reverse quicksort should produce canonical ordering", (pxs.reverse zip zs).forall{ case (a,b) => a.i == b.i })
+ assertTrue("Stable sort should produce exact ordering", (sxs zip zs).forall{ case (a,b) => a == b })
+ assertTrue("Reverse stable sort should produce canonical ordering", (rxs.reverse zip zs).forall{ case (a,b) => a.i == b.i })
+ assertTrue("Proxy sort and direct sort should produce exactly the same thing", (sxs zip sys.map(i => xs(i))).forall{ case (a,b) => a == b })
+ }
+
+ @Test def testSortConsistency: Unit = {
+ for {
+ size <- sizes
+ v <- variety
+ i <- 0 until math.min(100, math.max(math.min(math.floor(math.pow(v, size)/2), math.ceil(workLimit / (math.log(math.max(2,size))/math.log(2) * size))), 1).toInt)
+ } runOneTest(size, v)
+
+ for (size <- sizes) {
+ val b = Array.fill(size)(rng.nextBoolean)
+ val bfwd = Sorting.stableSort(b.clone: Seq[Boolean])
+ val bbkw = Sorting.stableSort(b.clone: Seq[Boolean], (x: Boolean, y: Boolean) => x && !y)
+ assertTrue("All falses should be first", bfwd.dropWhile(_ == false).forall(_ == true))
+ assertTrue("All falses should be last when sorted backwards", bbkw.dropWhile(_ == true).forall(_ == false))
+ assertTrue("Sorting booleans should preserve the number of trues", b.count(_ == true) == bfwd.count(_ == true))
+ assertTrue("Backwards sorting booleans should preserve the number of trues", b.count(_ == true) == bbkw.count(_ == true))
+ assertTrue("Sorting should not change the sizes of arrays", b.length == bfwd.length && b.length == bbkw.length)
+ }
+ }
+}