From d93b6552473468df297a08c0bef9ea0bf0f5c13a Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 19 Nov 2016 21:50:20 -0800 Subject: [SPARK-18458][CORE] Fix signed integer overflow problem at an expression in RadixSort.java ## What changes were proposed in this pull request? This PR prevents the result of an expression from becoming negative due to signed integer overflow (e.g. 0x10?????? * 8 < 0). It casts each operand to `long` before performing the calculation. Since the expression is then evaluated as a `long`, its result is positive. ## How was this patch tested? Manually executed query82 of TPC-DS with 100TB of data. Author: Kazuaki Ishizaki Closes #15907 from kiszk/SPARK-18458. --- .../util/collection/unsafe/sort/RadixSort.java | 48 +++++++++++----------- .../unsafe/sort/UnsafeInMemorySorter.java | 2 +- .../collection/unsafe/sort/RadixSortSuite.scala | 28 ++++++------- 3 files changed, 40 insertions(+), 38 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java index 404361734a..3dd3184710 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java @@ -17,6 +17,8 @@ package org.apache.spark.util.collection.unsafe.sort; +import com.google.common.primitives.Ints; + import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; @@ -40,14 +42,14 @@ public class RadixSort { * of always copying the data back to position zero for efficiency. 
*/ public static int sort( - LongArray array, int numRecords, int startByteIndex, int endByteIndex, + LongArray array, long numRecords, int startByteIndex, int endByteIndex, boolean desc, boolean signed) { assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0"; assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7"; assert endByteIndex > startByteIndex; assert numRecords * 2 <= array.size(); - int inIndex = 0; - int outIndex = numRecords; + long inIndex = 0; + long outIndex = numRecords; if (numRecords > 0) { long[][] counts = getCounts(array, numRecords, startByteIndex, endByteIndex); for (int i = startByteIndex; i <= endByteIndex; i++) { @@ -55,13 +57,13 @@ public class RadixSort { sortAtByte( array, numRecords, counts[i], i, inIndex, outIndex, desc, signed && i == endByteIndex); - int tmp = inIndex; + long tmp = inIndex; inIndex = outIndex; outIndex = tmp; } } } - return inIndex; + return Ints.checkedCast(inIndex); } /** @@ -78,14 +80,14 @@ public class RadixSort { * @param signed whether this is a signed (two's complement) sort (only applies to last byte). 
*/ private static void sortAtByte( - LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex, + LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex, boolean desc, boolean signed) { assert counts.length == 256; long[] offsets = transformCountsToOffsets( - counts, numRecords, array.getBaseOffset() + outIndex * 8, 8, desc, signed); + counts, numRecords, array.getBaseOffset() + outIndex * 8L, 8, desc, signed); Object baseObject = array.getBaseObject(); - long baseOffset = array.getBaseOffset() + inIndex * 8; - long maxOffset = baseOffset + numRecords * 8; + long baseOffset = array.getBaseOffset() + inIndex * 8L; + long maxOffset = baseOffset + numRecords * 8L; for (long offset = baseOffset; offset < maxOffset; offset += 8) { long value = Platform.getLong(baseObject, offset); int bucket = (int)((value >>> (byteIdx * 8)) & 0xff); @@ -106,13 +108,13 @@ public class RadixSort { * significant byte. If the byte does not need sorting the array will be null. */ private static long[][] getCounts( - LongArray array, int numRecords, int startByteIndex, int endByteIndex) { + LongArray array, long numRecords, int startByteIndex, int endByteIndex) { long[][] counts = new long[8][]; // Optimization: do a fast pre-pass to determine which byte indices we can skip for sorting. // If all the byte values at a particular index are the same we don't need to count it. long bitwiseMax = 0; long bitwiseMin = -1L; - long maxOffset = array.getBaseOffset() + numRecords * 8; + long maxOffset = array.getBaseOffset() + numRecords * 8L; Object baseObject = array.getBaseObject(); for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) { long value = Platform.getLong(baseObject, offset); @@ -146,18 +148,18 @@ public class RadixSort { * @return the input counts array. 
*/ private static long[] transformCountsToOffsets( - long[] counts, int numRecords, long outputOffset, int bytesPerRecord, + long[] counts, long numRecords, long outputOffset, long bytesPerRecord, boolean desc, boolean signed) { assert counts.length == 256; int start = signed ? 128 : 0; // output the negative records first (values 129-255). if (desc) { - int pos = numRecords; + long pos = numRecords; for (int i = start; i < start + 256; i++) { pos -= counts[i & 0xff]; counts[i & 0xff] = outputOffset + pos * bytesPerRecord; } } else { - int pos = 0; + long pos = 0; for (int i = start; i < start + 256; i++) { long tmp = counts[i & 0xff]; counts[i & 0xff] = outputOffset + pos * bytesPerRecord; @@ -176,8 +178,8 @@ public class RadixSort { */ public static int sortKeyPrefixArray( LongArray array, - int startIndex, - int numRecords, + long startIndex, + long numRecords, int startByteIndex, int endByteIndex, boolean desc, @@ -186,8 +188,8 @@ public class RadixSort { assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7"; assert endByteIndex > startByteIndex; assert numRecords * 4 <= array.size(); - int inIndex = startIndex; - int outIndex = startIndex + numRecords * 2; + long inIndex = startIndex; + long outIndex = startIndex + numRecords * 2L; if (numRecords > 0) { long[][] counts = getKeyPrefixArrayCounts( array, startIndex, numRecords, startByteIndex, endByteIndex); @@ -196,13 +198,13 @@ public class RadixSort { sortKeyPrefixArrayAtByte( array, numRecords, counts[i], i, inIndex, outIndex, desc, signed && i == endByteIndex); - int tmp = inIndex; + long tmp = inIndex; inIndex = outIndex; outIndex = tmp; } } } - return inIndex; + return Ints.checkedCast(inIndex); } /** @@ -210,7 +212,7 @@ public class RadixSort { * getCounts with some added parameters but that seems to hurt in benchmarks. 
*/ private static long[][] getKeyPrefixArrayCounts( - LongArray array, int startIndex, int numRecords, int startByteIndex, int endByteIndex) { + LongArray array, long startIndex, long numRecords, int startByteIndex, int endByteIndex) { long[][] counts = new long[8][]; long bitwiseMax = 0; long bitwiseMin = -1L; @@ -238,11 +240,11 @@ public class RadixSort { * Specialization of sortAtByte() for key-prefix arrays. */ private static void sortKeyPrefixArrayAtByte( - LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex, + LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex, boolean desc, boolean signed) { assert counts.length == 256; long[] offsets = transformCountsToOffsets( - counts, numRecords, array.getBaseOffset() + outIndex * 8, 16, desc, signed); + counts, numRecords, array.getBaseOffset() + outIndex * 8L, 16, desc, signed); Object baseObject = array.getBaseObject(); long baseOffset = array.getBaseOffset() + inIndex * 8L; long maxOffset = baseOffset + numRecords * 16L; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 2a71e68ada..252a35ec6b 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -322,7 +322,7 @@ public final class UnsafeInMemorySorter { if (sortComparator != null) { if (this.radixSortSupport != null) { offset = RadixSort.sortKeyPrefixArray( - array, nullBoundaryPos, (pos - nullBoundaryPos) / 2, 0, 7, + array, nullBoundaryPos, (pos - nullBoundaryPos) / 2L, 0, 7, radixSortSupport.sortDescending(), radixSortSupport.sortSigned()); } else { MemoryBlock unused = new MemoryBlock( diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala 
b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala index 366ffda778..d5956ea320 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala @@ -22,6 +22,8 @@ import java.util.{Arrays, Comparator} import scala.util.Random +import com.google.common.primitives.Ints + import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging import org.apache.spark.unsafe.array.LongArray @@ -30,7 +32,7 @@ import org.apache.spark.util.collection.Sorter import org.apache.spark.util.random.XORShiftRandom class RadixSortSuite extends SparkFunSuite with Logging { - private val N = 10000 // scale this down for more readable results + private val N = 10000L // scale this down for more readable results /** * Describes a type of sort to test, e.g. two's complement descending. Each sort type has @@ -73,22 +75,22 @@ class RadixSortSuite extends SparkFunSuite with Logging { }, 2, 4, false, false, true)) - private def generateTestData(size: Int, rand: => Long): (Array[JLong], LongArray) = { - val ref = Array.tabulate[Long](size) { i => rand } - val extended = ref ++ Array.fill[Long](size)(0) + private def generateTestData(size: Long, rand: => Long): (Array[JLong], LongArray) = { + val ref = Array.tabulate[Long](Ints.checkedCast(size)) { i => rand } + val extended = ref ++ Array.fill[Long](Ints.checkedCast(size))(0) (ref.map(i => new JLong(i)), new LongArray(MemoryBlock.fromLongArray(extended))) } - private def generateKeyPrefixTestData(size: Int, rand: => Long): (LongArray, LongArray) = { - val ref = Array.tabulate[Long](size * 2) { i => rand } - val extended = ref ++ Array.fill[Long](size * 2)(0) + private def generateKeyPrefixTestData(size: Long, rand: => Long): (LongArray, LongArray) = { + val ref = Array.tabulate[Long](Ints.checkedCast(size * 2)) { i => rand } + val extended = ref ++ 
Array.fill[Long](Ints.checkedCast(size * 2))(0) (new LongArray(MemoryBlock.fromLongArray(ref)), new LongArray(MemoryBlock.fromLongArray(extended))) } - private def collectToArray(array: LongArray, offset: Int, length: Int): Array[Long] = { + private def collectToArray(array: LongArray, offset: Int, length: Long): Array[Long] = { var i = 0 - val out = new Array[Long](length) + val out = new Array[Long](Ints.checkedCast(length)) while (i < length) { out(i) = array.get(offset + i) i += 1 @@ -107,15 +109,13 @@ class RadixSortSuite extends SparkFunSuite with Logging { } } - private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) { + private def referenceKeyPrefixSort(buf: LongArray, lo: Long, hi: Long, refCmp: PrefixComparator) { val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt))) new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort( - buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] { + buf, Ints.checkedCast(lo), Ints.checkedCast(hi), new Comparator[RecordPointerAndKeyPrefix] { override def compare( r1: RecordPointerAndKeyPrefix, - r2: RecordPointerAndKeyPrefix): Int = { - refCmp.compare(r1.keyPrefix, r2.keyPrefix) - } + r2: RecordPointerAndKeyPrefix): Int = refCmp.compare(r1.keyPrefix, r2.keyPrefix) }) } -- cgit v1.2.3