author     Davies Liu <davies@databricks.com>    2015-08-06 09:10:57 -0700
committer  Davies Liu <davies.liu@gmail.com>    2015-08-06 09:10:57 -0700
commit     5b965d64ee1687145ba793da749659c8f67384e8 (patch)
tree       a163c8545572b3270fac7159e0d2b6dba5fa4795
parent     aead18ffca36830e854fba32a1cac11a0b2e31d5 (diff)
[SPARK-9644] [SQL] Support update DecimalType with precision > 18 in UnsafeRow
In order to support updating a variable-length (actually fixed-length) object, its space must be preserved even when the value is null. As a consequence, we can no longer call setNullAt(i) for such a column, because setNullAt(i) drops the offset of the preserved space; callers should use setDecimal(i, null, precision) instead. After this change, we can do hash-based aggregation on DecimalType with precision > 18. In one test, this decreased the end-to-end running time of an aggregation query from 37 seconds (sort-based) to 24 seconds (hash-based).

cc rxin

Author: Davies Liu <davies@databricks.com>

Closes #7978 from davies/update_decimal and squashes the following commits:

bed8100 [Davies Liu] isSettable -> isMutable
923c9eb [Davies Liu] address comments and fix bug
385891d [Davies Liu] Merge branch 'master' of github.com:apache/spark into update_decimal
36a1872 [Davies Liu] fix tests
cd6c524 [Davies Liu] support set decimal with precision > 18
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java  74
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java  41
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala  15
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala  53
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala  8
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala  17
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala  4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala  2
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java  26
10 files changed, 183 insertions, 61 deletions
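
Before diving into the diff, the layout this patch introduces: a decimal with precision > 18 always occupies a fixed 16-byte slot in the row's variable-length region, and the column's 8-byte fixed-length word packs the slot offset together with the value's signum and magnitude length, so the slot survives a null and can be written again later. A minimal sketch of that packing, with hypothetical helper names (the patch inlines this arithmetic in UnsafeRow and DecimalWriter):

    // Illustrative only; mirrors the bit layout used by the code below.
    final class DecimalSlotLayout {
        // bits 32..63: byte offset (cursor) of the 16-byte slot within the row
        // bits  8..11: BigInteger.signum() + 1, i.e. 0, 1 or 2
        // bits  0..7:  number of 32-bit words in BigInteger's magnitude (at most 4)
        static long encode(long cursor, int signum, int magWords) {
            return (cursor << 32) | ((long) (((signum + 1) << 8) + magWords));
        }
        static long slotOffset(long word) { return word >>> 32; }
        static int signum(long word) { return (((int) (word & 0xfff)) >> 8) - 1; }
        static int magWords(long word) { return (int) (word & 0xff); }
    }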
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index e3e1622de0..e829acb628 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -65,11 +65,11 @@ public final class UnsafeRow extends MutableRow {
/**
* Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types)
*/
- public static final Set<DataType> settableFieldTypes;
+ public static final Set<DataType> mutableFieldTypes;
- // DecimalType(precision <= 18) is settable
+ // DecimalType is also mutable
static {
- settableFieldTypes = Collections.unmodifiableSet(
+ mutableFieldTypes = Collections.unmodifiableSet(
new HashSet<>(
Arrays.asList(new DataType[] {
NullType,
@@ -87,12 +87,16 @@ public final class UnsafeRow extends MutableRow {
public static boolean isFixedLength(DataType dt) {
if (dt instanceof DecimalType) {
- return ((DecimalType) dt).precision() < Decimal.MAX_LONG_DIGITS();
+ return ((DecimalType) dt).precision() <= Decimal.MAX_LONG_DIGITS();
} else {
- return settableFieldTypes.contains(dt);
+ return mutableFieldTypes.contains(dt);
}
}
+ public static boolean isMutable(DataType dt) {
+ return mutableFieldTypes.contains(dt) || dt instanceof DecimalType;
+ }
+
//////////////////////////////////////////////////////////////////////////////
// Private fields and methods
//////////////////////////////////////////////////////////////////////////////
@@ -238,17 +242,45 @@ public final class UnsafeRow extends MutableRow {
PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value);
}
+ /**
+ * Updates the decimal column.
+ *
+ * Note: In order to support updating a decimal with precision > 18, we can NOT call
+ * setNullAt() for this column; use setDecimal(ordinal, null, precision) instead.
+ */
@Override
public void setDecimal(int ordinal, Decimal value, int precision) {
assertIndexIsValid(ordinal);
- if (value == null) {
- setNullAt(ordinal);
- } else {
- if (precision <= Decimal.MAX_LONG_DIGITS()) {
+ if (precision <= Decimal.MAX_LONG_DIGITS()) {
+ // compact format
+ if (value == null) {
+ setNullAt(ordinal);
+ } else {
setLong(ordinal, value.toUnscaledLong());
+ }
+ } else {
+ // fixed length
+ long cursor = getLong(ordinal) >>> 32;
+ assert cursor > 0 : "invalid cursor " + cursor;
+ // zero-out the bytes
+ PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + cursor, 0L);
+ PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + cursor + 8, 0L);
+
+ if (value == null) {
+ setNullAt(ordinal);
+ // keep the offset for future update
+ PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(ordinal), cursor << 32);
} else {
- // TODO(davies): support update decimal (hold a bounded space even it's null)
- throw new UnsupportedOperationException();
+
+ final BigInteger integer = value.toJavaBigDecimal().unscaledValue();
+ final int[] mag = (int[]) PlatformDependent.UNSAFE.getObjectVolatile(integer,
+ PlatformDependent.BIG_INTEGER_MAG_OFFSET);
+ assert(mag.length <= 4);
+
+ // Write the bytes to the variable length portion.
+ PlatformDependent.copyMemory(mag, PlatformDependent.INT_ARRAY_OFFSET,
+ baseObject, baseOffset + cursor, mag.length * 4);
+ setLong(ordinal, (cursor << 32) | ((long) (((integer.signum() + 1) << 8) + mag.length)));
}
}
}
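
The practical effect of the null branch above: nulling a large-decimal column re-writes the offset word instead of leaving it zeroed, so the column can still be updated afterwards (setNullAt() zeroes the whole field word). A sketch that simulates, on a plain long, what the two null paths do to that word (values are hypothetical):

    // Sketch only, not code from the patch.
    public class NullPathSketch {
        public static void main(String[] args) {
            long cursor = 64;                                  // slot offset recorded at write time
            long word = (cursor << 32) | ((1 + 1) << 8) | 3;   // non-null: signum()+1 == 2, 3 mag words
            long afterSetNullAt = 0L;                          // setNullAt() zeroes the word: cursor lost
            long afterSetDecimalNull = (word >>> 32) << 32;    // setDecimal(i, null, p) keeps the cursor
            System.out.println(afterSetNullAt >>> 32);         // 0: slot unreachable
            System.out.println(afterSetDecimalNull >>> 32);    // 64: slot still reachable for updates
        }
    }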
@@ -343,6 +375,8 @@ public final class UnsafeRow extends MutableRow {
return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal));
}
+ private static byte[] EMPTY = new byte[0];
+
@Override
public Decimal getDecimal(int ordinal, int precision, int scale) {
if (isNullAt(ordinal)) {
@@ -351,10 +385,20 @@ public final class UnsafeRow extends MutableRow {
if (precision <= Decimal.MAX_LONG_DIGITS()) {
return Decimal.apply(getLong(ordinal), precision, scale);
} else {
- byte[] bytes = getBinary(ordinal);
- BigInteger bigInteger = new BigInteger(bytes);
- BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
- return Decimal.apply(new scala.math.BigDecimal(javaDecimal), precision, scale);
+ long offsetAndSize = getLong(ordinal);
+ long offset = offsetAndSize >>> 32;
+ int signum = ((int) (offsetAndSize & 0xfff) >> 8);
+ assert signum >=0 && signum <= 2 : "invalid signum " + signum;
+ int size = (int) (offsetAndSize & 0xff);
+ int[] mag = new int[size];
+ PlatformDependent.copyMemory(baseObject, baseOffset + offset,
+ mag, PlatformDependent.INT_ARRAY_OFFSET, size * 4);
+
+ // create a BigInteger using signum and mag
+ BigInteger v = new BigInteger(0, EMPTY); // create the initial object
+ PlatformDependent.UNSAFE.putInt(v, PlatformDependent.BIG_INTEGER_SIGNUM_OFFSET, signum - 1);
+ PlatformDependent.UNSAFE.putObjectVolatile(v, PlatformDependent.BIG_INTEGER_MAG_OFFSET, mag);
+ return Decimal.apply(new BigDecimal(v, scale), precision, scale);
}
}
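
A worked decoding of one packed word, following the arithmetic above (values are hypothetical):

    // Worked example of the decode in getDecimal().
    public class DecodeSketch {
        public static void main(String[] args) {
            // Written earlier as: slot at byte offset 80, a negative value
            // (signum() == -1, stored as 0), magnitude of two ints.
            long offsetAndSize = (80L << 32) | (0 << 8) | 2;
            long offset = offsetAndSize >>> 32;                  // 80
            int signum = ((int) (offsetAndSize & 0xfff)) >> 8;   // 0, so signum - 1 == -1
            int size = (int) (offsetAndSize & 0xff);             // 2: copy 2 * 4 = 8 bytes of mag
            System.out.println(offset + " " + (signum - 1) + " " + size);  // 80 -1 2
        }
    }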
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
index 3192873154..28e7ec0a0f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
@@ -17,9 +17,10 @@
package org.apache.spark.sql.catalyst.expressions;
+import java.math.BigInteger;
+
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.types.Decimal;
-import org.apache.spark.sql.types.MapData;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.types.ByteArray;
@@ -47,29 +48,41 @@ public class UnsafeRowWriters {
/** Writer for Decimal with precision larger than 18. */
public static class DecimalWriter {
-
+ private static final int SIZE = 16;
public static int getSize(Decimal input) {
// bounded size
- return 16;
+ return SIZE;
}
public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) {
+ final Object base = target.getBaseObject();
final long offset = target.getBaseOffset() + cursor;
- final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
- final int numBytes = bytes.length;
- assert(numBytes <= 16);
-
// zero-out the bytes
- PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset, 0L);
- PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset + 8, 0L);
+ PlatformDependent.UNSAFE.putLong(base, offset, 0L);
+ PlatformDependent.UNSAFE.putLong(base, offset + 8, 0L);
+
+ if (input == null) {
+ target.setNullAt(ordinal);
+ // keep the offset and length for update
+ int fieldOffset = UnsafeRow.calculateBitSetWidthInBytes(target.numFields()) + ordinal * 8;
+ PlatformDependent.UNSAFE.putLong(base, target.getBaseOffset() + fieldOffset,
+ ((long) cursor) << 32);
+ return SIZE;
+ }
- // Write the bytes to the variable length portion.
- PlatformDependent.copyMemory(bytes, PlatformDependent.BYTE_ARRAY_OFFSET,
- target.getBaseObject(), offset, numBytes);
+ final BigInteger integer = input.toJavaBigDecimal().unscaledValue();
+ int signum = integer.signum() + 1;
+ final int[] mag = (int[]) PlatformDependent.UNSAFE.getObjectVolatile(integer,
+ PlatformDependent.BIG_INTEGER_MAG_OFFSET);
+ assert(mag.length <= 4);
+ // Write the bytes to the variable length portion.
+ PlatformDependent.copyMemory(mag, PlatformDependent.INT_ARRAY_OFFSET,
+ base, target.getBaseOffset() + cursor, mag.length * 4);
// Set the fixed length portion.
- target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes));
- return 16;
+ target.setLong(ordinal, (((long) cursor) << 32) | ((long) ((signum << 8) + mag.length)));
+
+ return SIZE;
}
}
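
The constant SIZE of 16 bytes is safe for every supported precision: Spark caps DecimalType at 38 digits, and the largest unscaled magnitude, 10^38 - 1, needs only 127 bits, i.e. at most four 32-bit words of BigInteger's mag array (hence the assert on mag.length). A quick standalone check:

    import java.math.BigInteger;

    public class DecimalBoundCheck {
        public static void main(String[] args) {
            BigInteger maxUnscaled = BigInteger.TEN.pow(38).subtract(BigInteger.ONE);
            // prints 127; 127 <= 4 * 32, so the magnitude never exceeds 16 bytes
            System.out.println(maxUnscaled.bitLength());
        }
    }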
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index e4a8fc24da..ac58423cd8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
+import org.apache.spark.sql.types.DecimalType
// MutableProjection is not accessible in Java
abstract class BaseMutableProjection extends MutableProjection
@@ -43,14 +44,26 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
case (NoOp, _) => ""
case (e, i) =>
val evaluationCode = e.gen(ctx)
- evaluationCode.code +
+ if (e.dataType.isInstanceOf[DecimalType]) {
+ // Can't call setNullAt on DecimalType, because we need to keep the offset
s"""
+ ${evaluationCode.code}
+ if (${evaluationCode.isNull}) {
+ ${ctx.setColumn("mutableRow", e.dataType, i, null)};
+ } else {
+ ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)};
+ }
+ """
+ } else {
+ s"""
+ ${evaluationCode.code}
if (${evaluationCode.isNull}) {
mutableRow.setNullAt($i);
} else {
${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)};
}
"""
+ }
}
// collect projections into blocks as function has 64kb codesize limit in JVM
val projectionBlocks = new ArrayBuffer[String]()
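
For a decimal column the new branch therefore generates code of roughly this shape (an illustrative reconstruction with placeholder temporaries, assuming ctx.setColumn() maps DecimalType to setDecimal()):

    // Illustrative: generated projection code for a DecimalType(38, 18) column at ordinal 2.
    boolean isNull2 = /* null flag from the child expression */ false;
    Decimal value2 = /* value from the child expression */ null;
    if (isNull2) {
        mutableRow.setDecimal(2, null, 38);   // never mutableRow.setNullAt(2)
    } else {
        mutableRow.setDecimal(2, value2, 38);
    }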
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 71f8ea09f0..d8912df694 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -45,10 +45,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
/** Returns true iff we support this data type. */
def canSupport(dataType: DataType): Boolean = dataType match {
+ case NullType => true
case t: AtomicType => true
case _: CalendarIntervalType => true
case t: StructType => t.toSeq.forall(field => canSupport(field.dataType))
- case NullType => true
case t: ArrayType if canSupport(t.elementType) => true
case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true
case _ => false
@@ -56,7 +56,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
def genAdditionalSize(dt: DataType, ev: GeneratedExpressionCode): String = dt match {
case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
- s" + (${ev.isNull} ? 0 : $DecimalWriter.getSize(${ev.primitive}))"
+ s" + $DecimalWriter.getSize(${ev.primitive})"
case StringType =>
s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))"
case BinaryType =>
@@ -76,41 +76,41 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
ctx: CodeGenContext,
fieldType: DataType,
ev: GeneratedExpressionCode,
- primitive: String,
+ target: String,
index: Int,
cursor: String): String = fieldType match {
case _ if ctx.isPrimitiveType(fieldType) =>
- s"${ctx.setColumn(primitive, fieldType, index, ev.primitive)}"
+ s"${ctx.setColumn(target, fieldType, index, ev.primitive)}"
case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
s"""
// make sure Decimal object has the same scale as DecimalType
if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) {
- $CompactDecimalWriter.write($primitive, $index, $cursor, ${ev.primitive});
+ $CompactDecimalWriter.write($target, $index, $cursor, ${ev.primitive});
} else {
- $primitive.setNullAt($index);
+ $target.setNullAt($index);
}
"""
case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
s"""
// make sure Decimal object has the same scale as DecimalType
if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) {
- $cursor += $DecimalWriter.write($primitive, $index, $cursor, ${ev.primitive});
+ $cursor += $DecimalWriter.write($target, $index, $cursor, ${ev.primitive});
} else {
- $primitive.setNullAt($index);
+ $cursor += $DecimalWriter.write($target, $index, $cursor, null);
}
"""
case StringType =>
- s"$cursor += $StringWriter.write($primitive, $index, $cursor, ${ev.primitive})"
+ s"$cursor += $StringWriter.write($target, $index, $cursor, ${ev.primitive})"
case BinaryType =>
- s"$cursor += $BinaryWriter.write($primitive, $index, $cursor, ${ev.primitive})"
+ s"$cursor += $BinaryWriter.write($target, $index, $cursor, ${ev.primitive})"
case CalendarIntervalType =>
- s"$cursor += $IntervalWriter.write($primitive, $index, $cursor, ${ev.primitive})"
+ s"$cursor += $IntervalWriter.write($target, $index, $cursor, ${ev.primitive})"
case _: StructType =>
- s"$cursor += $StructWriter.write($primitive, $index, $cursor, ${ev.primitive})"
+ s"$cursor += $StructWriter.write($target, $index, $cursor, ${ev.primitive})"
case _: ArrayType =>
- s"$cursor += $ArrayWriter.write($primitive, $index, $cursor, ${ev.primitive})"
+ s"$cursor += $ArrayWriter.write($target, $index, $cursor, ${ev.primitive})"
case _: MapType =>
- s"$cursor += $MapWriter.write($primitive, $index, $cursor, ${ev.primitive})"
+ s"$cursor += $MapWriter.write($target, $index, $cursor, ${ev.primitive})"
case NullType => ""
case _ =>
throw new UnsupportedOperationException(s"Not supported DataType: $fieldType")
@@ -146,13 +146,24 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val fieldWriters = inputTypes.zip(convertedFields).zipWithIndex.map { case ((dt, ev), i) =>
val update = genFieldWriter(ctx, dt, ev, output, i, cursor)
- s"""
- if (${ev.isNull}) {
- $output.setNullAt($i);
- } else {
- $update;
- }
- """
+ if (dt.isInstanceOf[DecimalType]) {
+ // Can't call setNullAt() for DecimalType
+ s"""
+ if (${ev.isNull}) {
+ $cursor += $DecimalWriter.write($output, $i, $cursor, null);
+ } else {
+ $update;
+ }
+ """
+ } else {
+ s"""
+ if (${ev.isNull}) {
+ $output.setNullAt($i);
+ } else {
+ $update;
+ }
+ """
+ }
}.mkString("\n")
val code = s"""
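
And the corresponding writer code emitted by the branch above, for a DecimalType(38, 18) column at ordinal 1 (again an illustrative reconstruction; names are placeholders):

    // Illustrative: generated field writer after this change.
    if (isNull1) {
        // a null still consumes its 16-byte slot and records the offset
        cursor += UnsafeRowWriters.DecimalWriter.write(out, 1, cursor, null);
    } else {
        if (value1.changePrecision(38, 18)) {
            cursor += UnsafeRowWriters.DecimalWriter.write(out, 1, cursor, value1);
        } else {
            cursor += UnsafeRowWriters.DecimalWriter.write(out, 1, cursor, null);
        }
    }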
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 5e5de1d1dc..7657fb535d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
/**
* An extended interface to [[InternalRow]] that allows the values for each column to be updated.
@@ -39,6 +38,13 @@ abstract class MutableRow extends InternalRow {
def setLong(i: Int, value: Long): Unit = { update(i, value) }
def setFloat(i: Int, value: Float): Unit = { update(i, value) }
def setDouble(i: Int, value: Double): Unit = { update(i, value) }
+
+ /**
+ * Update the decimal column at `i`.
+ *
+ * Note: In order to support updating a decimal with precision > 18 in UnsafeRow,
+ * do NOT call setNullAt() for a decimal column on UnsafeRow; call setDecimal(i, null, precision).
+ */
def setDecimal(i: Int, value: Decimal, precision: Int) { update(i, value) }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 59491c5ba1..8c72203193 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -123,7 +123,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
DoubleType,
StringType,
BinaryType,
- DecimalType.USER_DEFAULT
+ DecimalType.USER_DEFAULT,
+ DecimalType.SYSTEM_DEFAULT
// ArrayType(IntegerType)
)
val converter = UnsafeProjection.create(fieldTypes)
@@ -151,6 +152,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(createdFromNull.getUTF8String(8) === null)
assert(createdFromNull.getBinary(9) === null)
assert(createdFromNull.getDecimal(10, 10, 0) === null)
+ assert(createdFromNull.getDecimal(11, 38, 18) === null)
// assert(createdFromNull.get(11) === null)
// If we have an UnsafeRow with columns that are initially non-null and we null out those
@@ -169,6 +171,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
r.update(8, UTF8String.fromString("hello"))
r.update(9, "world".getBytes)
r.setDecimal(10, Decimal(10), 10)
+ r.setDecimal(11, Decimal(10.00, 38, 18), 38)
// r.update(11, Array(11))
r
}
@@ -187,10 +190,17 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(setToNullAfterCreation.getBinary(9) === rowWithNoNullColumns.getBinary(9))
assert(setToNullAfterCreation.getDecimal(10, 10, 0) ===
rowWithNoNullColumns.getDecimal(10, 10, 0))
+ assert(setToNullAfterCreation.getDecimal(11, 38, 18) ===
+ rowWithNoNullColumns.getDecimal(11, 38, 18))
// assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
for (i <- fieldTypes.indices) {
- setToNullAfterCreation.setNullAt(i)
+ // Can't call setNullAt() on a DecimalType column
+ if (i == 11) {
+ setToNullAfterCreation.setDecimal(11, null, 38)
+ } else {
+ setToNullAfterCreation.setNullAt(i)
+ }
}
// There is some garbage left in the var-length area
assert(Arrays.equals(createdFromNull.getBytes, setToNullAfterCreation.getBytes()))
@@ -206,6 +216,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
// setToNullAfterCreation.update(8, UTF8String.fromString("hello"))
// setToNullAfterCreation.update(9, "world".getBytes)
setToNullAfterCreation.setDecimal(10, Decimal(10), 10)
+ setToNullAfterCreation.setDecimal(11, Decimal(10.00, 38, 18), 38)
// setToNullAfterCreation.update(11, Array(11))
assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0))
@@ -220,6 +231,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
// assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9))
assert(setToNullAfterCreation.getDecimal(10, 10, 0) ===
rowWithNoNullColumns.getDecimal(10, 10, 0))
+ assert(setToNullAfterCreation.getDecimal(11, 38, 18) ===
+ rowWithNoNullColumns.getDecimal(11, 38, 18))
// assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 43d06ce9bd..02458030b0 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -72,7 +72,7 @@ public final class UnsafeFixedWidthAggregationMap {
*/
public static boolean supportsAggregationBufferSchema(StructType schema) {
for (StructField field: schema.fields()) {
- if (!UnsafeRow.isFixedLength(field.dataType())) {
+ if (!UnsafeRow.isMutable(field.dataType())) {
return false;
}
}
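
This one-word change is what unlocks hash-based aggregation for wide decimals: the aggregation buffer schema check now asks for mutability rather than fixed length. A sketch of the new behavior from Java (mirroring the updated suite near the end of this diff):

    import org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap;
    import org.apache.spark.sql.types.*;

    public class BufferSchemaCheck {
        public static void main(String[] args) {
            StructType wideDecimal = new StructType(new StructField[] {
                new StructField("x", DataTypes.createDecimalType(38, 18), true, Metadata.empty())
            });
            // prints true after this patch; a StringType column would still print false
            System.out.println(
                UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(wideDecimal));
        }
    }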
@@ -111,8 +111,6 @@ public final class UnsafeFixedWidthAggregationMap {
// Initialize the buffer for aggregation value
final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema);
this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes();
- assert(this.emptyAggregationBuffer.length == aggregationBufferSchema.length() * 8 +
- UnsafeRow.calculateBitSetWidthInBytes(aggregationBufferSchema.length()));
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
index 78bcee16c9..40f6bff53d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -20,8 +20,6 @@ package org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2}
-import org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap
-import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.KVIterator
/**
@@ -57,7 +55,7 @@ class SortBasedAggregationIterator(
val bufferRowSize: Int = bufferSchema.length
val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
- val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isFixedLength)
+ val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable)
val buffer = if (useUnsafeBuffer) {
val unsafeProjection =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index b513c970cc..e03473041c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -93,7 +93,7 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
testWithMemoryLeakDetection("supported schemas") {
assert(supportsAggregationBufferSchema(
StructType(StructField("x", DecimalType.USER_DEFAULT) :: Nil)))
- assert(!supportsAggregationBufferSchema(
+ assert(supportsAggregationBufferSchema(
StructType(StructField("x", DecimalType.SYSTEM_DEFAULT) :: Nil)))
assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil)))
assert(
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java
index 192c6714b2..b2de2a2590 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java
@@ -18,6 +18,7 @@
package org.apache.spark.unsafe;
import java.lang.reflect.Field;
+import java.math.BigInteger;
import sun.misc.Unsafe;
@@ -87,6 +88,14 @@ public final class PlatformDependent {
_UNSAFE.putDouble(object, offset, value);
}
+ public static Object getObjectVolatile(Object object, long offset) {
+ return _UNSAFE.getObjectVolatile(object, offset);
+ }
+
+ public static void putObjectVolatile(Object object, long offset, Object value) {
+ _UNSAFE.putObjectVolatile(object, offset, value);
+ }
+
public static long allocateMemory(long size) {
return _UNSAFE.allocateMemory(size);
}
@@ -107,6 +116,10 @@ public final class PlatformDependent {
public static final int DOUBLE_ARRAY_OFFSET;
+ // Support for resetting final fields while deserializing
+ public static final long BIG_INTEGER_SIGNUM_OFFSET;
+ public static final long BIG_INTEGER_MAG_OFFSET;
+
/**
* Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to
* allow safepoint polling during a large copy.
@@ -129,11 +142,24 @@ public final class PlatformDependent {
INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class);
LONG_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(long[].class);
DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class);
+
+ long signumOffset = 0;
+ long magOffset = 0;
+ try {
+ signumOffset = _UNSAFE.objectFieldOffset(BigInteger.class.getDeclaredField("signum"));
+ magOffset = _UNSAFE.objectFieldOffset(BigInteger.class.getDeclaredField("mag"));
+ } catch (Exception ex) {
+ // should not happen
+ }
+ BIG_INTEGER_SIGNUM_OFFSET = signumOffset;
+ BIG_INTEGER_MAG_OFFSET = magOffset;
} else {
BYTE_ARRAY_OFFSET = 0;
INT_ARRAY_OFFSET = 0;
LONG_ARRAY_OFFSET = 0;
DOUBLE_ARRAY_OFFSET = 0;
+ BIG_INTEGER_SIGNUM_OFFSET = 0;
+ BIG_INTEGER_MAG_OFFSET = 0;
}
}
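
A caveat on the BigInteger offsets above: "signum" and "mag" are private fields of java.math.BigInteger whose names are an OpenJDK implementation detail, which is why the lookup is wrapped in a try/catch. A standalone sketch of the same reflection on a Java 7/8-era JVM:

    import java.lang.reflect.Field;
    import java.math.BigInteger;
    import sun.misc.Unsafe;

    public class BigIntegerOffsets {
        public static void main(String[] args) throws Exception {
            Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
            theUnsafe.setAccessible(true);
            Unsafe unsafe = (Unsafe) theUnsafe.get(null);

            long magOffset = unsafe.objectFieldOffset(BigInteger.class.getDeclaredField("mag"));
            int[] mag = (int[]) unsafe.getObjectVolatile(BigInteger.TEN.pow(20), magOffset);
            System.out.println(mag.length);  // 10^20 needs 67 bits, so 3 ints
        }
    }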