From dab1b0ae29a6d3017bdca23464f22a51d51eaae1 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 29 Sep 2014 11:25:32 -0700 Subject: [SPARK-3032][Shuffle] Fix key comparison integer overflow introduced sorting exception Previous key comparison in `ExternalSorter` will get wrong sorting result or exception when key comparison overflows, details can be seen in [SPARK-3032](https://issues.apache.org/jira/browse/SPARK-3032). Here fix this and add a unit test to prove it. Author: jerryshao Closes #2514 from jerryshao/SPARK-3032 and squashes the following commits: 6f3c302 [jerryshao] Improve the unit test according to comments 01911e6 [jerryshao] Change the test to show the contract violate exception 83acb38 [jerryshao] Minor changes according to comments fa2a08f [jerryshao] Fix key comparison integer overflow introduced sorting exception --- .../spark/util/collection/ExternalSorter.scala | 2 +- .../util/collection/ExternalSorterSuite.scala | 55 ++++++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) (limited to 'core') diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 0a152cb97a..644fa36818 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -144,7 +144,7 @@ private[spark] class ExternalSorter[K, V, C]( override def compare(a: K, b: K): Int = { val h1 = if (a == null) 0 else a.hashCode() val h2 = if (b == null) 0 else b.hashCode() - h1 - h2 + if (h1 < h2) -1 else if (h1 == h2) 0 else 1 } }) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 706faed980..f26e40fbd4 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -24,6 +24,8 @@ import org.scalatest.{PrivateMethodTester, FunSuite} import org.apache.spark._ import org.apache.spark.SparkContext._ +import scala.util.Random + class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMethodTester { private def createSparkConf(loadDefaults: Boolean): SparkConf = { val conf = new SparkConf(loadDefaults) @@ -707,4 +709,57 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe Some(agg), Some(new HashPartitioner(FEW_PARTITIONS)), None, None) assertDidNotBypassMergeSort(sorter4) } + + test("sort without breaking sorting contracts") { + val conf = createSparkConf(true) + conf.set("spark.shuffle.memoryFraction", "0.01") + conf.set("spark.shuffle.manager", "sort") + sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + + // Using wrongOrdering to show integer overflow introduced exception. + val rand = new Random(100L) + val wrongOrdering = new Ordering[String] { + override def compare(a: String, b: String) = { + val h1 = if (a == null) 0 else a.hashCode() + val h2 = if (b == null) 0 else b.hashCode() + h1 - h2 + } + } + + val testData = Array.tabulate(100000) { _ => rand.nextInt().toString } + + val sorter1 = new ExternalSorter[String, String, String]( + None, None, Some(wrongOrdering), None) + val thrown = intercept[IllegalArgumentException] { + sorter1.insertAll(testData.iterator.map(i => (i, i))) + sorter1.iterator + } + + assert(thrown.getClass() === classOf[IllegalArgumentException]) + assert(thrown.getMessage().contains("Comparison method violates its general contract")) + sorter1.stop() + + // Using aggregation and external spill to make sure ExternalSorter using + // partitionKeyComparator. + def createCombiner(i: String) = ArrayBuffer(i) + def mergeValue(c: ArrayBuffer[String], i: String) = c += i + def mergeCombiners(c1: ArrayBuffer[String], c2: ArrayBuffer[String]) = c1 ++= c2 + + val agg = new Aggregator[String, String, ArrayBuffer[String]]( + createCombiner, mergeValue, mergeCombiners) + + val sorter2 = new ExternalSorter[String, String, ArrayBuffer[String]]( + Some(agg), None, None, None) + sorter2.insertAll(testData.iterator.map(i => (i, i))) + + // To validate the hash ordering of key + var minKey = Int.MinValue + sorter2.iterator.foreach { case (k, v) => + val h = k.hashCode() + assert(h >= minKey) + minKey = h + } + + sorter2.stop() + } } -- cgit v1.2.3