aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Kazuaki Ishizaki <ishizaki@jp.ibm.com> 2016-11-19 21:50:20 -0800
committer Reynold Xin <rxin@databricks.com> 2016-11-19 21:50:20 -0800
commit d93b6552473468df297a08c0bef9ea0bf0f5c13a (patch)
tree bf244f58384ce0725bfa1b9c7c4c2ca4180076ad
parent 856e0042007c789dda4539fb19a5d4580999fbf4 (diff)
download spark-d93b6552473468df297a08c0bef9ea0bf0f5c13a.tar.gz
spark-d93b6552473468df297a08c0bef9ea0bf0f5c13a.tar.bz2
spark-d93b6552473468df297a08c0bef9ea0bf0f5c13a.zip
[SPARK-18458][CORE] Fix signed integer overflow problem at an expression in RadixSort.java
## What changes were proposed in this pull request? This PR prevents the result of an expression from becoming negative due to signed 32-bit integer overflow (e.g. 0x10?????? * 8 < 0). It casts each operand to `long` before the calculation is performed. Since the arithmetic is then carried out in 64 bits, the result of the expression stays positive. ## How was this patch tested? Manually executed query82 of TPC-DS at the 100TB scale. Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com> Closes #15907 from kiszk/SPARK-18458.
-rw-r--r-- core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java 48
-rw-r--r-- core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java 2
-rw-r--r-- core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala 28
3 files changed, 40 insertions, 38 deletions
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
index 404361734a..3dd3184710 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
@@ -17,6 +17,8 @@
package org.apache.spark.util.collection.unsafe.sort;
+import com.google.common.primitives.Ints;
+
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.LongArray;
@@ -40,14 +42,14 @@ public class RadixSort {
* of always copying the data back to position zero for efficiency.
*/
public static int sort(
- LongArray array, int numRecords, int startByteIndex, int endByteIndex,
+ LongArray array, long numRecords, int startByteIndex, int endByteIndex,
boolean desc, boolean signed) {
assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0";
assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
assert endByteIndex > startByteIndex;
assert numRecords * 2 <= array.size();
- int inIndex = 0;
- int outIndex = numRecords;
+ long inIndex = 0;
+ long outIndex = numRecords;
if (numRecords > 0) {
long[][] counts = getCounts(array, numRecords, startByteIndex, endByteIndex);
for (int i = startByteIndex; i <= endByteIndex; i++) {
@@ -55,13 +57,13 @@ public class RadixSort {
sortAtByte(
array, numRecords, counts[i], i, inIndex, outIndex,
desc, signed && i == endByteIndex);
- int tmp = inIndex;
+ long tmp = inIndex;
inIndex = outIndex;
outIndex = tmp;
}
}
}
- return inIndex;
+ return Ints.checkedCast(inIndex);
}
/**
@@ -78,14 +80,14 @@ public class RadixSort {
* @param signed whether this is a signed (two's complement) sort (only applies to last byte).
*/
private static void sortAtByte(
- LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex,
+ LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex,
boolean desc, boolean signed) {
assert counts.length == 256;
long[] offsets = transformCountsToOffsets(
- counts, numRecords, array.getBaseOffset() + outIndex * 8, 8, desc, signed);
+ counts, numRecords, array.getBaseOffset() + outIndex * 8L, 8, desc, signed);
Object baseObject = array.getBaseObject();
- long baseOffset = array.getBaseOffset() + inIndex * 8;
- long maxOffset = baseOffset + numRecords * 8;
+ long baseOffset = array.getBaseOffset() + inIndex * 8L;
+ long maxOffset = baseOffset + numRecords * 8L;
for (long offset = baseOffset; offset < maxOffset; offset += 8) {
long value = Platform.getLong(baseObject, offset);
int bucket = (int)((value >>> (byteIdx * 8)) & 0xff);
@@ -106,13 +108,13 @@ public class RadixSort {
* significant byte. If the byte does not need sorting the array will be null.
*/
private static long[][] getCounts(
- LongArray array, int numRecords, int startByteIndex, int endByteIndex) {
+ LongArray array, long numRecords, int startByteIndex, int endByteIndex) {
long[][] counts = new long[8][];
// Optimization: do a fast pre-pass to determine which byte indices we can skip for sorting.
// If all the byte values at a particular index are the same we don't need to count it.
long bitwiseMax = 0;
long bitwiseMin = -1L;
- long maxOffset = array.getBaseOffset() + numRecords * 8;
+ long maxOffset = array.getBaseOffset() + numRecords * 8L;
Object baseObject = array.getBaseObject();
for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) {
long value = Platform.getLong(baseObject, offset);
@@ -146,18 +148,18 @@ public class RadixSort {
* @return the input counts array.
*/
private static long[] transformCountsToOffsets(
- long[] counts, int numRecords, long outputOffset, int bytesPerRecord,
+ long[] counts, long numRecords, long outputOffset, long bytesPerRecord,
boolean desc, boolean signed) {
assert counts.length == 256;
int start = signed ? 128 : 0; // output the negative records first (values 129-255).
if (desc) {
- int pos = numRecords;
+ long pos = numRecords;
for (int i = start; i < start + 256; i++) {
pos -= counts[i & 0xff];
counts[i & 0xff] = outputOffset + pos * bytesPerRecord;
}
} else {
- int pos = 0;
+ long pos = 0;
for (int i = start; i < start + 256; i++) {
long tmp = counts[i & 0xff];
counts[i & 0xff] = outputOffset + pos * bytesPerRecord;
@@ -176,8 +178,8 @@ public class RadixSort {
*/
public static int sortKeyPrefixArray(
LongArray array,
- int startIndex,
- int numRecords,
+ long startIndex,
+ long numRecords,
int startByteIndex,
int endByteIndex,
boolean desc,
@@ -186,8 +188,8 @@ public class RadixSort {
assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
assert endByteIndex > startByteIndex;
assert numRecords * 4 <= array.size();
- int inIndex = startIndex;
- int outIndex = startIndex + numRecords * 2;
+ long inIndex = startIndex;
+ long outIndex = startIndex + numRecords * 2L;
if (numRecords > 0) {
long[][] counts = getKeyPrefixArrayCounts(
array, startIndex, numRecords, startByteIndex, endByteIndex);
@@ -196,13 +198,13 @@ public class RadixSort {
sortKeyPrefixArrayAtByte(
array, numRecords, counts[i], i, inIndex, outIndex,
desc, signed && i == endByteIndex);
- int tmp = inIndex;
+ long tmp = inIndex;
inIndex = outIndex;
outIndex = tmp;
}
}
}
- return inIndex;
+ return Ints.checkedCast(inIndex);
}
/**
@@ -210,7 +212,7 @@ public class RadixSort {
* getCounts with some added parameters but that seems to hurt in benchmarks.
*/
private static long[][] getKeyPrefixArrayCounts(
- LongArray array, int startIndex, int numRecords, int startByteIndex, int endByteIndex) {
+ LongArray array, long startIndex, long numRecords, int startByteIndex, int endByteIndex) {
long[][] counts = new long[8][];
long bitwiseMax = 0;
long bitwiseMin = -1L;
@@ -238,11 +240,11 @@ public class RadixSort {
* Specialization of sortAtByte() for key-prefix arrays.
*/
private static void sortKeyPrefixArrayAtByte(
- LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex,
+ LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex,
boolean desc, boolean signed) {
assert counts.length == 256;
long[] offsets = transformCountsToOffsets(
- counts, numRecords, array.getBaseOffset() + outIndex * 8, 16, desc, signed);
+ counts, numRecords, array.getBaseOffset() + outIndex * 8L, 16, desc, signed);
Object baseObject = array.getBaseObject();
long baseOffset = array.getBaseOffset() + inIndex * 8L;
long maxOffset = baseOffset + numRecords * 16L;
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 2a71e68ada..252a35ec6b 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -322,7 +322,7 @@ public final class UnsafeInMemorySorter {
if (sortComparator != null) {
if (this.radixSortSupport != null) {
offset = RadixSort.sortKeyPrefixArray(
- array, nullBoundaryPos, (pos - nullBoundaryPos) / 2, 0, 7,
+ array, nullBoundaryPos, (pos - nullBoundaryPos) / 2L, 0, 7,
radixSortSupport.sortDescending(), radixSortSupport.sortSigned());
} else {
MemoryBlock unused = new MemoryBlock(
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
index 366ffda778..d5956ea320 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
@@ -22,6 +22,8 @@ import java.util.{Arrays, Comparator}
import scala.util.Random
+import com.google.common.primitives.Ints
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.internal.Logging
import org.apache.spark.unsafe.array.LongArray
@@ -30,7 +32,7 @@ import org.apache.spark.util.collection.Sorter
import org.apache.spark.util.random.XORShiftRandom
class RadixSortSuite extends SparkFunSuite with Logging {
- private val N = 10000 // scale this down for more readable results
+ private val N = 10000L // scale this down for more readable results
/**
* Describes a type of sort to test, e.g. two's complement descending. Each sort type has
@@ -73,22 +75,22 @@ class RadixSortSuite extends SparkFunSuite with Logging {
},
2, 4, false, false, true))
- private def generateTestData(size: Int, rand: => Long): (Array[JLong], LongArray) = {
- val ref = Array.tabulate[Long](size) { i => rand }
- val extended = ref ++ Array.fill[Long](size)(0)
+ private def generateTestData(size: Long, rand: => Long): (Array[JLong], LongArray) = {
+ val ref = Array.tabulate[Long](Ints.checkedCast(size)) { i => rand }
+ val extended = ref ++ Array.fill[Long](Ints.checkedCast(size))(0)
(ref.map(i => new JLong(i)), new LongArray(MemoryBlock.fromLongArray(extended)))
}
- private def generateKeyPrefixTestData(size: Int, rand: => Long): (LongArray, LongArray) = {
- val ref = Array.tabulate[Long](size * 2) { i => rand }
- val extended = ref ++ Array.fill[Long](size * 2)(0)
+ private def generateKeyPrefixTestData(size: Long, rand: => Long): (LongArray, LongArray) = {
+ val ref = Array.tabulate[Long](Ints.checkedCast(size * 2)) { i => rand }
+ val extended = ref ++ Array.fill[Long](Ints.checkedCast(size * 2))(0)
(new LongArray(MemoryBlock.fromLongArray(ref)),
new LongArray(MemoryBlock.fromLongArray(extended)))
}
- private def collectToArray(array: LongArray, offset: Int, length: Int): Array[Long] = {
+ private def collectToArray(array: LongArray, offset: Int, length: Long): Array[Long] = {
var i = 0
- val out = new Array[Long](length)
+ val out = new Array[Long](Ints.checkedCast(length))
while (i < length) {
out(i) = array.get(offset + i)
i += 1
@@ -107,15 +109,13 @@ class RadixSortSuite extends SparkFunSuite with Logging {
}
}
- private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) {
+ private def referenceKeyPrefixSort(buf: LongArray, lo: Long, hi: Long, refCmp: PrefixComparator) {
val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt)))
new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort(
- buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] {
+ buf, Ints.checkedCast(lo), Ints.checkedCast(hi), new Comparator[RecordPointerAndKeyPrefix] {
override def compare(
r1: RecordPointerAndKeyPrefix,
- r2: RecordPointerAndKeyPrefix): Int = {
- refCmp.compare(r1.keyPrefix, r2.keyPrefix)
- }
+ r2: RecordPointerAndKeyPrefix): Int = refCmp.compare(r1.keyPrefix, r2.keyPrefix)
})
}