aboutsummaryrefslogtreecommitdiff
path: root/unsafe
diff options
context:
space:
mode:
authorNong Li <nong@databricks.com>2015-12-01 12:59:53 -0800
committerYin Huai <yhuai@databricks.com>2015-12-01 12:59:53 -0800
commit2cef1cdfbb5393270ae83179b6a4e50c3cbf9e93 (patch)
tree909226849a94a4efb0384665ae663f23295d8891 /unsafe
parent34e7093c1131162b3aa05b65a19a633a0b5b633e (diff)
downloadspark-2cef1cdfbb5393270ae83179b6a4e50c3cbf9e93.tar.gz
spark-2cef1cdfbb5393270ae83179b6a4e50c3cbf9e93.tar.bz2
spark-2cef1cdfbb5393270ae83179b6a4e50c3cbf9e93.zip
[SPARK-12030] Fix Platform.copyMemory to handle overlapping regions.
This bug was exposed as memory corruption in Timsort which uses copyMemory to copy large regions that can overlap. The prior implementation did not handle this case half the time and always copied forward, resulting in the data being corrupt. Author: Nong Li <nong@databricks.com> Closes #10068 from nongli/spark-12030.
Diffstat (limited to 'unsafe')
-rw-r--r--unsafe/src/main/java/org/apache/spark/unsafe/Platform.java27
-rw-r--r--unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java61
2 files changed, 82 insertions, 6 deletions
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
index 1c16da9829..0d6b215fe5 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
@@ -107,12 +107,27 @@ public final class Platform {
public static void copyMemory(
Object src, long srcOffset, Object dst, long dstOffset, long length) {
- while (length > 0) {
- long size = Math.min(length, UNSAFE_COPY_THRESHOLD);
- _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size);
- length -= size;
- srcOffset += size;
- dstOffset += size;
+ // Check if dstOffset is before or after srcOffset to determine if we should copy
+ // forward or backwards. This is necessary in case src and dst overlap.
+ if (dstOffset < srcOffset) {
+ while (length > 0) {
+ long size = Math.min(length, UNSAFE_COPY_THRESHOLD);
+ _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size);
+ length -= size;
+ srcOffset += size;
+ dstOffset += size;
+ }
+ } else {
+ srcOffset += length;
+ dstOffset += length;
+ while (length > 0) {
+ long size = Math.min(length, UNSAFE_COPY_THRESHOLD);
+ srcOffset -= size;
+ dstOffset -= size;
+ _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size);
+ length -= size;
+ }
+
}
}
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
new file mode 100644
index 0000000000..693ec6ec58
--- /dev/null
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class PlatformUtilSuite {
+
+ @Test
+ public void overlappingCopyMemory() {
+ byte[] data = new byte[3 * 1024 * 1024];
+ int size = 2 * 1024 * 1024;
+ for (int i = 0; i < data.length; ++i) {
+ data[i] = (byte)i;
+ }
+
+ Platform.copyMemory(data, Platform.BYTE_ARRAY_OFFSET, data, Platform.BYTE_ARRAY_OFFSET, size);
+ for (int i = 0; i < data.length; ++i) {
+ Assert.assertEquals((byte)i, data[i]);
+ }
+
+ Platform.copyMemory(
+ data,
+ Platform.BYTE_ARRAY_OFFSET + 1,
+ data,
+ Platform.BYTE_ARRAY_OFFSET,
+ size);
+ for (int i = 0; i < size; ++i) {
+ Assert.assertEquals((byte)(i + 1), data[i]);
+ }
+
+ for (int i = 0; i < data.length; ++i) {
+ data[i] = (byte)i;
+ }
+ Platform.copyMemory(
+ data,
+ Platform.BYTE_ARRAY_OFFSET,
+ data,
+ Platform.BYTE_ARRAY_OFFSET + 1,
+ size);
+ for (int i = 0; i < size; ++i) {
+ Assert.assertEquals((byte)i, data[i + 1]);
+ }
+ }
+}