aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java16
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala39
2 files changed, 53 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
index af61e2011f..0e4264fe8d 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
@@ -45,7 +45,13 @@ public class BufferHolder {
}
public BufferHolder(UnsafeRow row, int initialSize) {
- this.fixedSize = UnsafeRow.calculateBitSetWidthInBytes(row.numFields()) + 8 * row.numFields();
+ int bitsetWidthInBytes = UnsafeRow.calculateBitSetWidthInBytes(row.numFields());
+ if (row.numFields() > (Integer.MAX_VALUE - initialSize - bitsetWidthInBytes) / 8) {
+ throw new UnsupportedOperationException(
+ "Cannot create BufferHolder for input UnsafeRow because there are " +
+ "too many fields (number of fields: " + row.numFields() + ")");
+ }
+ this.fixedSize = bitsetWidthInBytes + 8 * row.numFields();
this.buffer = new byte[fixedSize + initialSize];
this.row = row;
this.row.pointTo(buffer, buffer.length);
@@ -55,10 +61,16 @@ public class BufferHolder {
* Grows the buffer by at least neededSize and points the row to the buffer.
*/
public void grow(int neededSize) {
+ if (neededSize > Integer.MAX_VALUE - totalSize()) {
+ throw new UnsupportedOperationException(
+ "Cannot grow BufferHolder by size " + neededSize + " because the size after growing " +
+ "exceeds size limitation " + Integer.MAX_VALUE);
+ }
final int length = totalSize() + neededSize;
if (buffer.length < length) {
// This will not happen frequently, because the buffer is re-used.
- final byte[] tmp = new byte[length * 2];
+ int newLength = length < Integer.MAX_VALUE / 2 ? length * 2 : Integer.MAX_VALUE;
+ final byte[] tmp = new byte[newLength];
Platform.copyMemory(
buffer,
Platform.BYTE_ARRAY_OFFSET,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala
new file mode 100644
index 0000000000..c7c386b5b8
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+
+class BufferHolderSuite extends SparkFunSuite {
+
+ test("SPARK-16071 Check the size limit to avoid integer overflow") {
+ var e = intercept[UnsupportedOperationException] {
+ new BufferHolder(new UnsafeRow(Int.MaxValue / 8))
+ }
+ assert(e.getMessage.contains("too many fields"))
+
+ val holder = new BufferHolder(new UnsafeRow(1000))
+ holder.reset()
+ holder.grow(1000)
+ e = intercept[UnsupportedOperationException] {
+ holder.grow(Integer.MAX_VALUE)
+ }
+ assert(e.getMessage.contains("exceeds size limitation"))
+ }
+}