diff options
author | Davies Liu <davies@databricks.com> | 2015-08-06 09:10:57 -0700 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2015-08-06 09:10:57 -0700 |
commit | 5b965d64ee1687145ba793da749659c8f67384e8 (patch) | |
tree | a163c8545572b3270fac7159e0d2b6dba5fa4795 | |
parent | aead18ffca36830e854fba32a1cac11a0b2e31d5 (diff) | |
download | spark-5b965d64ee1687145ba793da749659c8f67384e8.tar.gz spark-5b965d64ee1687145ba793da749659c8f67384e8.tar.bz2 spark-5b965d64ee1687145ba793da749659c8f67384e8.zip |
[SPARK-9644] [SQL] Support update DecimalType with precision > 18 in UnsafeRow
In order to support update a varlength (actually fixed length) object, the space should be preserved even it's null. And, we can't call setNullAt(i) for it anymore, we because setNullAt(i) will remove the offset of the preserved space, should call setDecimal(i, null, precision) instead.
After this, we can do hash based aggregation on DecimalType with precision > 18. In a tests, this could decrease the end-to-end run time of aggregation query from 37 seconds (sort based) to 24 seconds (hash based).
cc rxin
Author: Davies Liu <davies@databricks.com>
Closes #7978 from davies/update_decimal and squashes the following commits:
bed8100 [Davies Liu] isSettable -> isMutable
923c9eb [Davies Liu] address comments and fix bug
385891d [Davies Liu] Merge branch 'master' of github.com:apache/spark into update_decimal
36a1872 [Davies Liu] fix tests
cd6c524 [Davies Liu] support set decimal with precision > 18
10 files changed, 183 insertions, 61 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index e3e1622de0..e829acb628 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -65,11 +65,11 @@ public final class UnsafeRow extends MutableRow { /** * Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types) */ - public static final Set<DataType> settableFieldTypes; + public static final Set<DataType> mutableFieldTypes; - // DecimalType(precision <= 18) is settable + // DecimalType is also mutable static { - settableFieldTypes = Collections.unmodifiableSet( + mutableFieldTypes = Collections.unmodifiableSet( new HashSet<>( Arrays.asList(new DataType[] { NullType, @@ -87,12 +87,16 @@ public final class UnsafeRow extends MutableRow { public static boolean isFixedLength(DataType dt) { if (dt instanceof DecimalType) { - return ((DecimalType) dt).precision() < Decimal.MAX_LONG_DIGITS(); + return ((DecimalType) dt).precision() <= Decimal.MAX_LONG_DIGITS(); } else { - return settableFieldTypes.contains(dt); + return mutableFieldTypes.contains(dt); } } + public static boolean isMutable(DataType dt) { + return mutableFieldTypes.contains(dt) || dt instanceof DecimalType; + } + ////////////////////////////////////////////////////////////////////////////// // Private fields and methods ////////////////////////////////////////////////////////////////////////////// @@ -238,17 +242,45 @@ public final class UnsafeRow extends MutableRow { PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); } + /** + * Updates the decimal column. + * + * Note: In order to support update a decimal with precision > 18, CAN NOT call + * setNullAt() for this column. + */ @Override public void setDecimal(int ordinal, Decimal value, int precision) { assertIndexIsValid(ordinal); - if (value == null) { - setNullAt(ordinal); - } else { - if (precision <= Decimal.MAX_LONG_DIGITS()) { + if (precision <= Decimal.MAX_LONG_DIGITS()) { + // compact format + if (value == null) { + setNullAt(ordinal); + } else { setLong(ordinal, value.toUnscaledLong()); + } + } else { + // fixed length + long cursor = getLong(ordinal) >>> 32; + assert cursor > 0 : "invalid cursor " + cursor; + // zero-out the bytes + PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + cursor, 0L); + PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + cursor + 8, 0L); + + if (value == null) { + setNullAt(ordinal); + // keep the offset for future update + PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(ordinal), cursor << 32); } else { - // TODO(davies): support update decimal (hold a bounded space even it's null) - throw new UnsupportedOperationException(); + + final BigInteger integer = value.toJavaBigDecimal().unscaledValue(); + final int[] mag = (int[]) PlatformDependent.UNSAFE.getObjectVolatile(integer, + PlatformDependent.BIG_INTEGER_MAG_OFFSET); + assert(mag.length <= 4); + + // Write the bytes to the variable length portion. + PlatformDependent.copyMemory(mag, PlatformDependent.INT_ARRAY_OFFSET, + baseObject, baseOffset + cursor, mag.length * 4); + setLong(ordinal, (cursor << 32) | ((long) (((integer.signum() + 1) << 8) + mag.length))); } } } @@ -343,6 +375,8 @@ public final class UnsafeRow extends MutableRow { return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal)); } + private static byte[] EMPTY = new byte[0]; + @Override public Decimal getDecimal(int ordinal, int precision, int scale) { if (isNullAt(ordinal)) { @@ -351,10 +385,20 @@ public final class UnsafeRow extends MutableRow { if (precision <= Decimal.MAX_LONG_DIGITS()) { return Decimal.apply(getLong(ordinal), precision, scale); } else { - byte[] bytes = getBinary(ordinal); - BigInteger bigInteger = new BigInteger(bytes); - BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); - return Decimal.apply(new scala.math.BigDecimal(javaDecimal), precision, scale); + long offsetAndSize = getLong(ordinal); + long offset = offsetAndSize >>> 32; + int signum = ((int) (offsetAndSize & 0xfff) >> 8); + assert signum >=0 && signum <= 2 : "invalid signum " + signum; + int size = (int) (offsetAndSize & 0xff); + int[] mag = new int[size]; + PlatformDependent.copyMemory(baseObject, baseOffset + offset, + mag, PlatformDependent.INT_ARRAY_OFFSET, size * 4); + + // create a BigInteger using signum and mag + BigInteger v = new BigInteger(0, EMPTY); // create the initial object + PlatformDependent.UNSAFE.putInt(v, PlatformDependent.BIG_INTEGER_SIGNUM_OFFSET, signum - 1); + PlatformDependent.UNSAFE.putObjectVolatile(v, PlatformDependent.BIG_INTEGER_MAG_OFFSET, mag); + return Decimal.apply(new BigDecimal(v, scale), precision, scale); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java index 3192873154..28e7ec0a0f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.expressions; +import java.math.BigInteger; + import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.types.Decimal; -import org.apache.spark.sql.types.MapData; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.ByteArray; @@ -47,29 +48,41 @@ public class UnsafeRowWriters { /** Writer for Decimal with precision larger than 18. */ public static class DecimalWriter { - + private static final int SIZE = 16; public static int getSize(Decimal input) { // bounded size - return 16; + return SIZE; } public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) { + final Object base = target.getBaseObject(); final long offset = target.getBaseOffset() + cursor; - final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); - final int numBytes = bytes.length; - assert(numBytes <= 16); - // zero-out the bytes - PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset, 0L); - PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset + 8, 0L); + PlatformDependent.UNSAFE.putLong(base, offset, 0L); + PlatformDependent.UNSAFE.putLong(base, offset + 8, 0L); + + if (input == null) { + target.setNullAt(ordinal); + // keep the offset and length for update + int fieldOffset = UnsafeRow.calculateBitSetWidthInBytes(target.numFields()) + ordinal * 8; + PlatformDependent.UNSAFE.putLong(base, target.getBaseOffset() + fieldOffset, + ((long) cursor) << 32); + return SIZE; + } - // Write the bytes to the variable length portion. - PlatformDependent.copyMemory(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, - target.getBaseObject(), offset, numBytes); + final BigInteger integer = input.toJavaBigDecimal().unscaledValue(); + int signum = integer.signum() + 1; + final int[] mag = (int[]) PlatformDependent.UNSAFE.getObjectVolatile(integer, + PlatformDependent.BIG_INTEGER_MAG_OFFSET); + assert(mag.length <= 4); + // Write the bytes to the variable length portion. + PlatformDependent.copyMemory(mag, PlatformDependent.INT_ARRAY_OFFSET, + base, target.getBaseOffset() + cursor, mag.length * 4); // Set the fixed length portion. - target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)); - return 16; + target.setLong(ordinal, (((long) cursor) << 32) | ((long) ((signum << 8) + mag.length))); + + return SIZE; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index e4a8fc24da..ac58423cd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp +import org.apache.spark.sql.types.DecimalType // MutableProjection is not accessible in Java abstract class BaseMutableProjection extends MutableProjection @@ -43,14 +44,26 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu case (NoOp, _) => "" case (e, i) => val evaluationCode = e.gen(ctx) - evaluationCode.code + + if (e.dataType.isInstanceOf[DecimalType]) { + // Can't call setNullAt on DecimalType, because we need to keep the offset s""" + ${evaluationCode.code} + if (${evaluationCode.isNull}) { + ${ctx.setColumn("mutableRow", e.dataType, i, null)}; + } else { + ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; + } + """ + } else { + s""" + ${evaluationCode.code} if (${evaluationCode.isNull}) { mutableRow.setNullAt($i); } else { ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; } """ + } } // collect projections into blocks as function has 64kb codesize limit in JVM val projectionBlocks = new ArrayBuffer[String]() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 71f8ea09f0..d8912df694 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -45,10 +45,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro /** Returns true iff we support this data type. */ def canSupport(dataType: DataType): Boolean = dataType match { + case NullType => true case t: AtomicType => true case _: CalendarIntervalType => true case t: StructType => t.toSeq.forall(field => canSupport(field.dataType)) - case NullType => true case t: ArrayType if canSupport(t.elementType) => true case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true case _ => false @@ -56,7 +56,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro def genAdditionalSize(dt: DataType, ev: GeneratedExpressionCode): String = dt match { case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => - s" + (${ev.isNull} ? 0 : $DecimalWriter.getSize(${ev.primitive}))" + s" + $DecimalWriter.getSize(${ev.primitive})" case StringType => s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))" case BinaryType => @@ -76,41 +76,41 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodeGenContext, fieldType: DataType, ev: GeneratedExpressionCode, - primitive: String, + target: String, index: Int, cursor: String): String = fieldType match { case _ if ctx.isPrimitiveType(fieldType) => - s"${ctx.setColumn(primitive, fieldType, index, ev.primitive)}" + s"${ctx.setColumn(target, fieldType, index, ev.primitive)}" case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => s""" // make sure Decimal object has the same scale as DecimalType if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) { - $CompactDecimalWriter.write($primitive, $index, $cursor, ${ev.primitive}); + $CompactDecimalWriter.write($target, $index, $cursor, ${ev.primitive}); } else { - $primitive.setNullAt($index); + $target.setNullAt($index); } """ case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => s""" // make sure Decimal object has the same scale as DecimalType if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) { - $cursor += $DecimalWriter.write($primitive, $index, $cursor, ${ev.primitive}); + $cursor += $DecimalWriter.write($target, $index, $cursor, ${ev.primitive}); } else { - $primitive.setNullAt($index); + $cursor += $DecimalWriter.write($target, $index, $cursor, null); } """ case StringType => - s"$cursor += $StringWriter.write($primitive, $index, $cursor, ${ev.primitive})" + s"$cursor += $StringWriter.write($target, $index, $cursor, ${ev.primitive})" case BinaryType => - s"$cursor += $BinaryWriter.write($primitive, $index, $cursor, ${ev.primitive})" + s"$cursor += $BinaryWriter.write($target, $index, $cursor, ${ev.primitive})" case CalendarIntervalType => - s"$cursor += $IntervalWriter.write($primitive, $index, $cursor, ${ev.primitive})" + s"$cursor += $IntervalWriter.write($target, $index, $cursor, ${ev.primitive})" case _: StructType => - s"$cursor += $StructWriter.write($primitive, $index, $cursor, ${ev.primitive})" + s"$cursor += $StructWriter.write($target, $index, $cursor, ${ev.primitive})" case _: ArrayType => - s"$cursor += $ArrayWriter.write($primitive, $index, $cursor, ${ev.primitive})" + s"$cursor += $ArrayWriter.write($target, $index, $cursor, ${ev.primitive})" case _: MapType => - s"$cursor += $MapWriter.write($primitive, $index, $cursor, ${ev.primitive})" + s"$cursor += $MapWriter.write($target, $index, $cursor, ${ev.primitive})" case NullType => "" case _ => throw new UnsupportedOperationException(s"Not supported DataType: $fieldType") @@ -146,13 +146,24 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val fieldWriters = inputTypes.zip(convertedFields).zipWithIndex.map { case ((dt, ev), i) => val update = genFieldWriter(ctx, dt, ev, output, i, cursor) - s""" - if (${ev.isNull}) { - $output.setNullAt($i); - } else { - $update; - } - """ + if (dt.isInstanceOf[DecimalType]) { + // Can't call setNullAt() for DecimalType + s""" + if (${ev.isNull}) { + $cursor += $DecimalWriter.write($output, $i, $cursor, null); + } else { + $update; + } + """ + } else { + s""" + if (${ev.isNull}) { + $output.setNullAt($i); + } else { + $update; + } + """ + } }.mkString("\n") val code = s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 5e5de1d1dc..7657fb535d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} /** * An extended interface to [[InternalRow]] that allows the values for each column to be updated. @@ -39,6 +38,13 @@ abstract class MutableRow extends InternalRow { def setLong(i: Int, value: Long): Unit = { update(i, value) } def setFloat(i: Int, value: Float): Unit = { update(i, value) } def setDouble(i: Int, value: Double): Unit = { update(i, value) } + + /** + * Update the decimal column at `i`. + * + * Note: In order to support update decimal with precision > 18 in UnsafeRow, + * CAN NOT call setNullAt() for decimal column on UnsafeRow, call setDecimal(i, null, precision). + */ def setDecimal(i: Int, value: Decimal, precision: Int) { update(i, value) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 59491c5ba1..8c72203193 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -123,7 +123,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { DoubleType, StringType, BinaryType, - DecimalType.USER_DEFAULT + DecimalType.USER_DEFAULT, + DecimalType.SYSTEM_DEFAULT // ArrayType(IntegerType) ) val converter = UnsafeProjection.create(fieldTypes) @@ -151,6 +152,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(createdFromNull.getUTF8String(8) === null) assert(createdFromNull.getBinary(9) === null) assert(createdFromNull.getDecimal(10, 10, 0) === null) + assert(createdFromNull.getDecimal(11, 38, 18) === null) // assert(createdFromNull.get(11) === null) // If we have an UnsafeRow with columns that are initially non-null and we null out those @@ -169,6 +171,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { r.update(8, UTF8String.fromString("hello")) r.update(9, "world".getBytes) r.setDecimal(10, Decimal(10), 10) + r.setDecimal(11, Decimal(10.00, 38, 18), 38) // r.update(11, Array(11)) r } @@ -187,10 +190,17 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(setToNullAfterCreation.getBinary(9) === rowWithNoNullColumns.getBinary(9)) assert(setToNullAfterCreation.getDecimal(10, 10, 0) === rowWithNoNullColumns.getDecimal(10, 10, 0)) + assert(setToNullAfterCreation.getDecimal(11, 38, 18) === + rowWithNoNullColumns.getDecimal(11, 38, 18)) // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) for (i <- fieldTypes.indices) { - setToNullAfterCreation.setNullAt(i) + // Cann't call setNullAt() on DecimalType + if (i == 11) { + setToNullAfterCreation.setDecimal(11, null, 38) + } else { + setToNullAfterCreation.setNullAt(i) + } } // There are some garbage left in the var-length area assert(Arrays.equals(createdFromNull.getBytes, setToNullAfterCreation.getBytes())) @@ -206,6 +216,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { // setToNullAfterCreation.update(8, UTF8String.fromString("hello")) // setToNullAfterCreation.update(9, "world".getBytes) setToNullAfterCreation.setDecimal(10, Decimal(10), 10) + setToNullAfterCreation.setDecimal(11, Decimal(10.00, 38, 18), 38) // setToNullAfterCreation.update(11, Array(11)) assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) @@ -220,6 +231,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { // assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9)) assert(setToNullAfterCreation.getDecimal(10, 10, 0) === rowWithNoNullColumns.getDecimal(10, 10, 0)) + assert(setToNullAfterCreation.getDecimal(11, 38, 18) === + rowWithNoNullColumns.getDecimal(11, 38, 18)) // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 43d06ce9bd..02458030b0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -72,7 +72,7 @@ public final class UnsafeFixedWidthAggregationMap { */ public static boolean supportsAggregationBufferSchema(StructType schema) { for (StructField field: schema.fields()) { - if (!UnsafeRow.isFixedLength(field.dataType())) { + if (!UnsafeRow.isMutable(field.dataType())) { return false; } } @@ -111,8 +111,6 @@ public final class UnsafeFixedWidthAggregationMap { // Initialize the buffer for aggregation value final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema); this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); - assert(this.emptyAggregationBuffer.length == aggregationBufferSchema.length() * 8 + - UnsafeRow.calculateBitSetWidthInBytes(aggregationBufferSchema.length())); } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 78bcee16c9..40f6bff53d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -20,8 +20,6 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2} -import org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap -import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.KVIterator /** @@ -57,7 +55,7 @@ class SortBasedAggregationIterator( val bufferRowSize: Int = bufferSchema.length val genericMutableBuffer = new GenericMutableRow(bufferRowSize) - val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isFixedLength) + val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable) val buffer = if (useUnsafeBuffer) { val unsafeProjection = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index b513c970cc..e03473041c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -93,7 +93,7 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { testWithMemoryLeakDetection("supported schemas") { assert(supportsAggregationBufferSchema( StructType(StructField("x", DecimalType.USER_DEFAULT) :: Nil))) - assert(!supportsAggregationBufferSchema( + assert(supportsAggregationBufferSchema( StructType(StructField("x", DecimalType.SYSTEM_DEFAULT) :: Nil))) assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil))) assert( diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java index 192c6714b2..b2de2a2590 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java @@ -18,6 +18,7 @@ package org.apache.spark.unsafe; import java.lang.reflect.Field; +import java.math.BigInteger; import sun.misc.Unsafe; @@ -87,6 +88,14 @@ public final class PlatformDependent { _UNSAFE.putDouble(object, offset, value); } + public static Object getObjectVolatile(Object object, long offset) { + return _UNSAFE.getObjectVolatile(object, offset); + } + + public static void putObjectVolatile(Object object, long offset, Object value) { + _UNSAFE.putObjectVolatile(object, offset, value); + } + public static long allocateMemory(long size) { return _UNSAFE.allocateMemory(size); } @@ -107,6 +116,10 @@ public final class PlatformDependent { public static final int DOUBLE_ARRAY_OFFSET; + // Support for resetting final fields while deserializing + public static final long BIG_INTEGER_SIGNUM_OFFSET; + public static final long BIG_INTEGER_MAG_OFFSET; + /** * Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to * allow safepoint polling during a large copy. @@ -129,11 +142,24 @@ public final class PlatformDependent { INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class); LONG_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(long[].class); DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class); + + long signumOffset = 0; + long magOffset = 0; + try { + signumOffset = _UNSAFE.objectFieldOffset(BigInteger.class.getDeclaredField("signum")); + magOffset = _UNSAFE.objectFieldOffset(BigInteger.class.getDeclaredField("mag")); + } catch (Exception ex) { + // should not happen + } + BIG_INTEGER_SIGNUM_OFFSET = signumOffset; + BIG_INTEGER_MAG_OFFSET = magOffset; } else { BYTE_ARRAY_OFFSET = 0; INT_ARRAY_OFFSET = 0; LONG_ARRAY_OFFSET = 0; DOUBLE_ARRAY_OFFSET = 0; + BIG_INTEGER_SIGNUM_OFFSET = 0; + BIG_INTEGER_MAG_OFFSET = 0; } } |