 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java                 | 74
 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java          | 41
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala | 15
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala  | 53
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala                    |  8
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala | 17
 sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java           |  4
 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala |  4
 sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala    |  2
 unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java                                 | 26
 10 files changed, 183 insertions(+), 61 deletions(-)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index e3e1622de0..e829acb628 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -65,11 +65,11 @@ public final class UnsafeRow extends MutableRow {
/**
* Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types)
*/
- public static final Set<DataType> settableFieldTypes;
+ public static final Set<DataType> mutableFieldTypes;
- // DecimalType(precision <= 18) is settable
+ // DecimalType is also mutable
static {
- settableFieldTypes = Collections.unmodifiableSet(
+ mutableFieldTypes = Collections.unmodifiableSet(
new HashSet<>(
Arrays.asList(new DataType[] {
NullType,
@@ -87,12 +87,16 @@ public final class UnsafeRow extends MutableRow {
public static boolean isFixedLength(DataType dt) {
if (dt instanceof DecimalType) {
- return ((DecimalType) dt).precision() < Decimal.MAX_LONG_DIGITS();
+ return ((DecimalType) dt).precision() <= Decimal.MAX_LONG_DIGITS();
} else {
- return settableFieldTypes.contains(dt);
+ return mutableFieldTypes.contains(dt);
}
}
+ public static boolean isMutable(DataType dt) {
+ return mutableFieldTypes.contains(dt) || dt instanceof DecimalType;
+ }
+
//////////////////////////////////////////////////////////////////////////////
// Private fields and methods
//////////////////////////////////////////////////////////////////////////////
@@ -238,17 +242,45 @@ public final class UnsafeRow extends MutableRow {
PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value);
}
+ /**
+ * Updates the decimal column.
+ *
+ * Note: In order to support updating a decimal with precision > 18, callers must NOT use
+ * setNullAt() for this column; use setDecimal(ordinal, null, precision) instead.
+ */
@Override
public void setDecimal(int ordinal, Decimal value, int precision) {
assertIndexIsValid(ordinal);
- if (value == null) {
- setNullAt(ordinal);
- } else {
- if (precision <= Decimal.MAX_LONG_DIGITS()) {
+ if (precision <= Decimal.MAX_LONG_DIGITS()) {
+ // compact format
+ if (value == null) {
+ setNullAt(ordinal);
+ } else {
setLong(ordinal, value.toUnscaledLong());
+ }
+ } else {
+ // fixed length
+ long cursor = getLong(ordinal) >>> 32;
+ assert cursor > 0 : "invalid cursor " + cursor;
+ // zero-out the bytes
+ PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + cursor, 0L);
+ PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + cursor + 8, 0L);
+
+ if (value == null) {
+ setNullAt(ordinal);
+ // keep the offset for future update
+ PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(ordinal), cursor << 32);
} else {
- // TODO(davies): support update decimal (hold a bounded space even it's null)
- throw new UnsupportedOperationException();
+
+ final BigInteger integer = value.toJavaBigDecimal().unscaledValue();
+ final int[] mag = (int[]) PlatformDependent.UNSAFE.getObjectVolatile(integer,
+ PlatformDependent.BIG_INTEGER_MAG_OFFSET);
+ assert(mag.length <= 4);
+
+ // Write the bytes to the variable length portion.
+ PlatformDependent.copyMemory(mag, PlatformDependent.INT_ARRAY_OFFSET,
+ baseObject, baseOffset + cursor, mag.length * 4);
+ setLong(ordinal, (cursor << 32) | ((long) (((integer.signum() + 1) << 8) + mag.length)));
}
}
}
@@ -343,6 +375,8 @@ public final class UnsafeRow extends MutableRow {
return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal));
}
+ private static byte[] EMPTY = new byte[0];
+
@Override
public Decimal getDecimal(int ordinal, int precision, int scale) {
if (isNullAt(ordinal)) {
@@ -351,10 +385,20 @@ public final class UnsafeRow extends MutableRow {
if (precision <= Decimal.MAX_LONG_DIGITS()) {
return Decimal.apply(getLong(ordinal), precision, scale);
} else {
- byte[] bytes = getBinary(ordinal);
- BigInteger bigInteger = new BigInteger(bytes);
- BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
- return Decimal.apply(new scala.math.BigDecimal(javaDecimal), precision, scale);
+ long offsetAndSize = getLong(ordinal);
+ long offset = offsetAndSize >>> 32;
+ int signum = ((int) (offsetAndSize & 0xfff) >> 8);
+ assert signum >=0 && signum <= 2 : "invalid signum " + signum;
+ int size = (int) (offsetAndSize & 0xff);
+ int[] mag = new int[size];
+ PlatformDependent.copyMemory(baseObject, baseOffset + offset,
+ mag, PlatformDependent.INT_ARRAY_OFFSET, size * 4);
+
+ // create a BigInteger using signum and mag
+ BigInteger v = new BigInteger(0, EMPTY); // create a placeholder; its fields are overwritten below
+ PlatformDependent.UNSAFE.putInt(v, PlatformDependent.BIG_INTEGER_SIGNUM_OFFSET, signum - 1);
+ PlatformDependent.UNSAFE.putObjectVolatile(v, PlatformDependent.BIG_INTEGER_MAG_OFFSET, mag);
+ return Decimal.apply(new BigDecimal(v, scale), precision, scale);
}
}
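
To make the new on-row encoding concrete: for precision > 18, the 8-byte field slot packs the cursor (the offset of the value's 16-byte slot in the variable-length region) into the upper 32 bits, signum + 1 into bits 8..11, and the length of BigInteger's magnitude array into the low 8 bits. A minimal sketch of that packing, with illustrative helper names that are not part of the patch:

public final class DecimalSlot {
  // Pack cursor, signum (-1, 0, or 1) and magnitude length into one long,
  // mirroring setDecimal() above.
  static long pack(long cursor, int signum, int magLength) {
    assert magLength <= 4;  // 38 decimal digits fit in four 32-bit words
    return (cursor << 32) | ((long) (((signum + 1) << 8) + magLength));
  }

  static long cursor(long slot)   { return slot >>> 32; }                     // upper 32 bits
  static int signum(long slot)    { return ((int) (slot & 0xfff) >> 8) - 1; } // bits 8..11
  static int magLength(long slot) { return (int) (slot & 0xff); }             // low 8 bits

  public static void main(String[] args) {
    long slot = pack(64, -1, 4);
    System.out.println(cursor(slot) + " " + signum(slot) + " " + magLength(slot)); // 64 -1 4
  }
}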
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
index 3192873154..28e7ec0a0f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
@@ -17,9 +17,10 @@
package org.apache.spark.sql.catalyst.expressions;
+import java.math.BigInteger;
+
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.types.Decimal;
-import org.apache.spark.sql.types.MapData;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.types.ByteArray;
@@ -47,29 +48,41 @@ public class UnsafeRowWriters {
/** Writer for Decimal with precision larger than 18. */
public static class DecimalWriter {
-
+ private static final int SIZE = 16;
public static int getSize(Decimal input) {
// bounded size
- return 16;
+ return SIZE;
}
public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) {
+ final Object base = target.getBaseObject();
final long offset = target.getBaseOffset() + cursor;
- final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
- final int numBytes = bytes.length;
- assert(numBytes <= 16);
-
// zero-out the bytes
- PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset, 0L);
- PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset + 8, 0L);
+ PlatformDependent.UNSAFE.putLong(base, offset, 0L);
+ PlatformDependent.UNSAFE.putLong(base, offset + 8, 0L);
+
+ if (input == null) {
+ target.setNullAt(ordinal);
+ // keep the offset and length for update
+ int fieldOffset = UnsafeRow.calculateBitSetWidthInBytes(target.numFields()) + ordinal * 8;
+ PlatformDependent.UNSAFE.putLong(base, target.getBaseOffset() + fieldOffset,
+ ((long) cursor) << 32);
+ return SIZE;
+ }
- // Write the bytes to the variable length portion.
- PlatformDependent.copyMemory(bytes, PlatformDependent.BYTE_ARRAY_OFFSET,
- target.getBaseObject(), offset, numBytes);
+ final BigInteger integer = input.toJavaBigDecimal().unscaledValue();
+ int signum = integer.signum() + 1;
+ final int[] mag = (int[]) PlatformDependent.UNSAFE.getObjectVolatile(integer,
+ PlatformDependent.BIG_INTEGER_MAG_OFFSET);
+ assert(mag.length <= 4);
+ // Write the bytes to the variable length portion.
+ PlatformDependent.copyMemory(mag, PlatformDependent.INT_ARRAY_OFFSET,
+ base, target.getBaseOffset() + cursor, mag.length * 4);
// Set the fixed length portion.
- target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes));
- return 16;
+ target.setLong(ordinal, (((long) cursor) << 32) | ((long) ((signum << 8) + mag.length)));
+
+ return SIZE;
}
}
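
One consequence worth spelling out: DecimalWriter now consumes its 16 bytes even when the input is null, so a row's size, and the field's offset within it, no longer depend on null-ness. A back-of-the-envelope check under the UnsafeRow layout (null bitset, one 8-byte word per field, then the variable-length region); the helper names are illustrative:

public final class RowSizeSketch {
  static int bitSetWidthInBytes(int numFields) {
    return ((numFields + 63) / 64) * 8;  // mirrors UnsafeRow.calculateBitSetWidthInBytes
  }

  static int rowSize(int numFields, int variableBytes) {
    return bitSetWidthInBytes(numFields) + numFields * 8 + variableBytes;
  }

  public static void main(String[] args) {
    // A (LONG, DECIMAL(38, 18)) row: the decimal reserves 16 bytes whether or
    // not it is null, so the row is always 8 + 2 * 8 + 16 = 40 bytes.
    System.out.println(rowSize(2, 16));  // 40
  }
}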
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index e4a8fc24da..ac58423cd8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
+import org.apache.spark.sql.types.DecimalType
// MutableProjection is not accessible in Java
abstract class BaseMutableProjection extends MutableProjection
@@ -43,14 +44,26 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
case (NoOp, _) => ""
case (e, i) =>
val evaluationCode = e.gen(ctx)
- evaluationCode.code +
+ if (e.dataType.isInstanceOf[DecimalType]) {
+ // Can't call setNullAt on DecimalType, because we need to keep the offset
s"""
+ ${evaluationCode.code}
+ if (${evaluationCode.isNull}) {
+ ${ctx.setColumn("mutableRow", e.dataType, i, null)};
+ } else {
+ ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)};
+ }
+ """
+ } else {
+ s"""
+ ${evaluationCode.code}
if (${evaluationCode.isNull}) {
mutableRow.setNullAt($i);
} else {
${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)};
}
"""
+ }
}
// collect projections into blocks as function has 64kb codesize limit in JVM
val projectionBlocks = new ArrayBuffer[String]()
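
For a DECIMAL(38, 18) column at ordinal 0, the generated update now takes roughly the following shape (a hand-written reconstruction, not the literal codegen output):

import org.apache.spark.sql.catalyst.expressions.MutableRow;
import org.apache.spark.sql.types.Decimal;

final class GeneratedShape {
  // ctx.setColumn(..., null) emits a setDecimal call with a null value, which
  // preserves the 16-byte slot, where setNullAt(0) would have dropped it.
  static void writeColumn0(MutableRow mutableRow, boolean isNull0, Decimal value0) {
    if (isNull0) {
      mutableRow.setDecimal(0, null, 38);
    } else {
      mutableRow.setDecimal(0, value0, 38);
    }
  }
}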
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 71f8ea09f0..d8912df694 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -45,10 +45,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
/** Returns true iff we support this data type. */
def canSupport(dataType: DataType): Boolean = dataType match {
+ case NullType => true
case t: AtomicType => true
case _: CalendarIntervalType => true
case t: StructType => t.toSeq.forall(field => canSupport(field.dataType))
- case NullType => true
case t: ArrayType if canSupport(t.elementType) => true
case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true
case _ => false
@@ -56,7 +56,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
def genAdditionalSize(dt: DataType, ev: GeneratedExpressionCode): String = dt match {
case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
- s" + (${ev.isNull} ? 0 : $DecimalWriter.getSize(${ev.primitive}))"
+ s" + $DecimalWriter.getSize(${ev.primitive})"
case StringType =>
s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))"
case BinaryType =>
@@ -76,41 +76,41 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
ctx: CodeGenContext,
fieldType: DataType,
ev: GeneratedExpressionCode,
- primitive: String,
+ target: String,
index: Int,
cursor: String): String = fieldType match {
case _ if ctx.isPrimitiveType(fieldType) =>
- s"${ctx.setColumn(primitive, fieldType, index, ev.primitive)}"
+ s"${ctx.setColumn(target, fieldType, index, ev.primitive)}"
case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
s"""
// make sure Decimal object has the same scale as DecimalType
if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) {
- $CompactDecimalWriter.write($primitive, $index, $cursor, ${ev.primitive});
+ $CompactDecimalWriter.write($target, $index, $cursor, ${ev.primitive});
} else {
- $primitive.setNullAt($index);
+ $target.setNullAt($index);
}
"""
case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
s"""
// make sure Decimal object has the same scale as DecimalType
if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) {
- $cursor += $DecimalWriter.write($primitive, $index, $cursor, ${ev.primitive});
+ $cursor += $DecimalWriter.write($target, $index, $cursor, ${ev.primitive});
} else {
- $primitive.setNullAt($index);
+ $cursor += $DecimalWriter.write($target, $index, $cursor, null);
}
"""
case StringType =>
- s"$cursor += $StringWriter.write($primitive, $index, $cursor, ${ev.primitive})"
+ s"$cursor += $StringWriter.write($target, $index, $cursor, ${ev.primitive})"
case BinaryType =>
- s"$cursor += $BinaryWriter.write($primitive, $index, $cursor, ${ev.primitive})"
+ s"$cursor += $BinaryWriter.write($target, $index, $cursor, ${ev.primitive})"
case CalendarIntervalType =>
- s"$cursor += $IntervalWriter.write($primitive, $index, $cursor, ${ev.primitive})"
+ s"$cursor += $IntervalWriter.write($target, $index, $cursor, ${ev.primitive})"
case _: StructType =>
- s"$cursor += $StructWriter.write($primitive, $index, $cursor, ${ev.primitive})"
+ s"$cursor += $StructWriter.write($target, $index, $cursor, ${ev.primitive})"
case _: ArrayType =>
- s"$cursor += $ArrayWriter.write($primitive, $index, $cursor, ${ev.primitive})"
+ s"$cursor += $ArrayWriter.write($target, $index, $cursor, ${ev.primitive})"
case _: MapType =>
- s"$cursor += $MapWriter.write($primitive, $index, $cursor, ${ev.primitive})"
+ s"$cursor += $MapWriter.write($target, $index, $cursor, ${ev.primitive})"
case NullType => ""
case _ =>
throw new UnsupportedOperationException(s"Not supported DataType: $fieldType")
@@ -146,13 +146,24 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val fieldWriters = inputTypes.zip(convertedFields).zipWithIndex.map { case ((dt, ev), i) =>
val update = genFieldWriter(ctx, dt, ev, output, i, cursor)
- s"""
- if (${ev.isNull}) {
- $output.setNullAt($i);
- } else {
- $update;
- }
- """
+ if (dt.isInstanceOf[DecimalType]) {
+ // Can't call setNullAt() for DecimalType
+ s"""
+ if (${ev.isNull}) {
+ $cursor += $DecimalWriter.write($output, $i, $cursor, null);
+ } else {
+ $update;
+ }
+ """
+ } else {
+ s"""
+ if (${ev.isNull}) {
+ $output.setNullAt($i);
+ } else {
+ $update;
+ }
+ """
+ }
}.mkString("\n")
val code = s"""
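
Both decimal branches above gate the write on changePrecision(), which rounds the value to the target scale and reports overflow; on overflow the column is written as null (for precision > 18, still through DecimalWriter, so the slot stays reserved). A small check of that contract, assuming Spark's Decimal API as used in this patch:

import java.math.BigDecimal;
import org.apache.spark.sql.types.Decimal;

public final class ChangePrecisionSketch {
  public static void main(String[] args) {
    Decimal a = Decimal.apply(new BigDecimal("12.345"));
    System.out.println(a.changePrecision(4, 2));  // true: rounds to 12.35, fits DECIMAL(4, 2)

    Decimal b = Decimal.apply(new BigDecimal("123.456"));
    System.out.println(b.changePrecision(4, 2));  // false: 123.46 needs precision 5
  }
}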
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 5e5de1d1dc..7657fb535d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
/**
* An extended interface to [[InternalRow]] that allows the values for each column to be updated.
@@ -39,6 +38,13 @@ abstract class MutableRow extends InternalRow {
def setLong(i: Int, value: Long): Unit = { update(i, value) }
def setFloat(i: Int, value: Float): Unit = { update(i, value) }
def setDouble(i: Int, value: Double): Unit = { update(i, value) }
+
+ /**
+ * Update the decimal column at `i`.
+ *
+ * Note: In order to support updating a decimal with precision > 18 in UnsafeRow,
+ * do NOT call setNullAt() for a decimal column; call setDecimal(i, null, precision) instead.
+ */
def setDecimal(i: Int, value: Decimal, precision: Int) { update(i, value) }
}
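
A hypothetical caller-side view of that contract (assumes column 0 of the row is DECIMAL(38, 18)):

import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.types.Decimal;

final class SetDecimalContract {
  static void updateThenNullOut(UnsafeRow row) {
    row.setDecimal(0, Decimal.apply("1.5"), 38);  // in-place update of the 16-byte slot
    row.setDecimal(0, null, 38);                  // nulls the value but keeps the offset
    // row.setNullAt(0) would lose the offset and break any later setDecimal(0, ...).
  }
}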
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 59491c5ba1..8c72203193 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -123,7 +123,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
DoubleType,
StringType,
BinaryType,
- DecimalType.USER_DEFAULT
+ DecimalType.USER_DEFAULT,
+ DecimalType.SYSTEM_DEFAULT
// ArrayType(IntegerType)
)
val converter = UnsafeProjection.create(fieldTypes)
@@ -151,6 +152,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(createdFromNull.getUTF8String(8) === null)
assert(createdFromNull.getBinary(9) === null)
assert(createdFromNull.getDecimal(10, 10, 0) === null)
+ assert(createdFromNull.getDecimal(11, 38, 18) === null)
// assert(createdFromNull.get(11) === null)
// If we have an UnsafeRow with columns that are initially non-null and we null out those
@@ -169,6 +171,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
r.update(8, UTF8String.fromString("hello"))
r.update(9, "world".getBytes)
r.setDecimal(10, Decimal(10), 10)
+ r.setDecimal(11, Decimal(10.00, 38, 18), 38)
// r.update(11, Array(11))
r
}
@@ -187,10 +190,17 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(setToNullAfterCreation.getBinary(9) === rowWithNoNullColumns.getBinary(9))
assert(setToNullAfterCreation.getDecimal(10, 10, 0) ===
rowWithNoNullColumns.getDecimal(10, 10, 0))
+ assert(setToNullAfterCreation.getDecimal(11, 38, 18) ===
+ rowWithNoNullColumns.getDecimal(11, 38, 18))
// assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
for (i <- fieldTypes.indices) {
- setToNullAfterCreation.setNullAt(i)
+ // Can't call setNullAt() on a DecimalType column
+ if (i == 11) {
+ setToNullAfterCreation.setDecimal(11, null, 38)
+ } else {
+ setToNullAfterCreation.setNullAt(i)
+ }
}
// There is some garbage left in the var-length area
assert(Arrays.equals(createdFromNull.getBytes, setToNullAfterCreation.getBytes()))
@@ -206,6 +216,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
// setToNullAfterCreation.update(8, UTF8String.fromString("hello"))
// setToNullAfterCreation.update(9, "world".getBytes)
setToNullAfterCreation.setDecimal(10, Decimal(10), 10)
+ setToNullAfterCreation.setDecimal(11, Decimal(10.00, 38, 18), 38)
// setToNullAfterCreation.update(11, Array(11))
assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0))
@@ -220,6 +231,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
// assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9))
assert(setToNullAfterCreation.getDecimal(10, 10, 0) ===
rowWithNoNullColumns.getDecimal(10, 10, 0))
+ assert(setToNullAfterCreation.getDecimal(11, 38, 18) ===
+ rowWithNoNullColumns.getDecimal(11, 38, 18))
// assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 43d06ce9bd..02458030b0 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -72,7 +72,7 @@ public final class UnsafeFixedWidthAggregationMap {
*/
public static boolean supportsAggregationBufferSchema(StructType schema) {
for (StructField field: schema.fields()) {
- if (!UnsafeRow.isFixedLength(field.dataType())) {
+ if (!UnsafeRow.isMutable(field.dataType())) {
return false;
}
}
@@ -111,8 +111,6 @@ public final class UnsafeFixedWidthAggregationMap {
// Initialize the buffer for aggregation value
final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema);
this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes();
- assert(this.emptyAggregationBuffer.length == aggregationBufferSchema.length() * 8 +
- UnsafeRow.calculateBitSetWidthInBytes(aggregationBufferSchema.length()));
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
index 78bcee16c9..40f6bff53d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -20,8 +20,6 @@ package org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2}
-import org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap
-import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.KVIterator
/**
@@ -57,7 +55,7 @@ class SortBasedAggregationIterator(
val bufferRowSize: Int = bufferSchema.length
val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
- val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isFixedLength)
+ val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable)
val buffer = if (useUnsafeBuffer) {
val unsafeProjection =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index b513c970cc..e03473041c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -93,7 +93,7 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
testWithMemoryLeakDetection("supported schemas") {
assert(supportsAggregationBufferSchema(
StructType(StructField("x", DecimalType.USER_DEFAULT) :: Nil)))
- assert(!supportsAggregationBufferSchema(
+ assert(supportsAggregationBufferSchema(
StructType(StructField("x", DecimalType.SYSTEM_DEFAULT) :: Nil)))
assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil)))
assert(
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java
index 192c6714b2..b2de2a2590 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java
@@ -18,6 +18,7 @@
package org.apache.spark.unsafe;
import java.lang.reflect.Field;
+import java.math.BigInteger;
import sun.misc.Unsafe;
@@ -87,6 +88,14 @@ public final class PlatformDependent {
_UNSAFE.putDouble(object, offset, value);
}
+ public static Object getObjectVolatile(Object object, long offset) {
+ return _UNSAFE.getObjectVolatile(object, offset);
+ }
+
+ public static void putObjectVolatile(Object object, long offset, Object value) {
+ _UNSAFE.putObjectVolatile(object, offset, value);
+ }
+
public static long allocateMemory(long size) {
return _UNSAFE.allocateMemory(size);
}
@@ -107,6 +116,10 @@ public final class PlatformDependent {
public static final int DOUBLE_ARRAY_OFFSET;
+ // Offsets of java.math.BigInteger's internal "signum" and "mag" fields, for direct Unsafe access
+ public static final long BIG_INTEGER_SIGNUM_OFFSET;
+ public static final long BIG_INTEGER_MAG_OFFSET;
+
/**
* Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to
* allow safepoint polling during a large copy.
@@ -129,11 +142,24 @@ public final class PlatformDependent {
INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class);
LONG_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(long[].class);
DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class);
+
+ long signumOffset = 0;
+ long magOffset = 0;
+ try {
+ signumOffset = _UNSAFE.objectFieldOffset(BigInteger.class.getDeclaredField("signum"));
+ magOffset = _UNSAFE.objectFieldOffset(BigInteger.class.getDeclaredField("mag"));
+ } catch (Exception ex) {
+ // should not happen
+ }
+ BIG_INTEGER_SIGNUM_OFFSET = signumOffset;
+ BIG_INTEGER_MAG_OFFSET = magOffset;
} else {
BYTE_ARRAY_OFFSET = 0;
INT_ARRAY_OFFSET = 0;
LONG_ARRAY_OFFSET = 0;
DOUBLE_ARRAY_OFFSET = 0;
+ BIG_INTEGER_SIGNUM_OFFSET = 0;
+ BIG_INTEGER_MAG_OFFSET = 0;
}
}
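
What those two offsets expose, expressed with plain reflection instead of Unsafe (the patch uses Unsafe for speed; on JDK 9+ this sketch would additionally need --add-opens java.base/java.math=ALL-UNNAMED):

import java.lang.reflect.Field;
import java.math.BigInteger;

public final class BigIntegerLayout {
  public static void main(String[] args) throws Exception {
    Field mag = BigInteger.class.getDeclaredField("mag");        // int[], big-endian magnitude
    Field signum = BigInteger.class.getDeclaredField("signum");  // -1, 0, or 1
    mag.setAccessible(true);
    signum.setAccessible(true);

    // The largest DECIMAL(38) unscaled value, 10^38 - 1, needs only 127 bits,
    // which is why the writers can assert mag.length <= 4 (four 32-bit words).
    BigInteger max38 = BigInteger.TEN.pow(38).subtract(BigInteger.ONE);
    System.out.println(max38.bitLength());                // 127
    System.out.println(signum.getInt(max38));             // 1
    System.out.println(((int[]) mag.get(max38)).length);  // 4
  }
}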