-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java  31
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java  78
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java  102
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala  57
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala  81
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala  68
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala  6
7 files changed, 291 insertions, 132 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 366615f6fe..850838af9b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -402,7 +402,7 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
if (isNullAt(ordinal)) return null;
final long offsetAndSize = getLong(ordinal);
final int offset = (int) (offsetAndSize >> 32);
- final int size = (int) (offsetAndSize & ((1L << 32) - 1));
+ final int size = (int) offsetAndSize;
return UTF8String.fromAddress(baseObject, baseOffset + offset, size);
}
@@ -413,7 +413,7 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
} else {
final long offsetAndSize = getLong(ordinal);
final int offset = (int) (offsetAndSize >> 32);
- final int size = (int) (offsetAndSize & ((1L << 32) - 1));
+ final int size = (int) offsetAndSize;
final byte[] bytes = new byte[size];
Platform.copyMemory(
baseObject,
@@ -446,7 +446,7 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
} else {
final long offsetAndSize = getLong(ordinal);
final int offset = (int) (offsetAndSize >> 32);
- final int size = (int) (offsetAndSize & ((1L << 32) - 1));
+ final int size = (int) offsetAndSize;
final UnsafeRow row = new UnsafeRow();
row.pointTo(baseObject, baseOffset + offset, numFields, size);
return row;
@@ -460,7 +460,7 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
} else {
final long offsetAndSize = getLong(ordinal);
final int offset = (int) (offsetAndSize >> 32);
- final int size = (int) (offsetAndSize & ((1L << 32) - 1));
+ final int size = (int) offsetAndSize;
final UnsafeArrayData array = new UnsafeArrayData();
array.pointTo(baseObject, baseOffset + offset, size);
return array;
@@ -474,7 +474,7 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
} else {
final long offsetAndSize = getLong(ordinal);
final int offset = (int) (offsetAndSize >> 32);
- final int size = (int) (offsetAndSize & ((1L << 32) - 1));
+ final int size = (int) offsetAndSize;
final UnsafeMapData map = new UnsafeMapData();
map.pointTo(baseObject, baseOffset + offset, size);
return map;
@@ -618,6 +618,27 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
buffer.position(pos + sizeInBytes);
}
+ /**
+ * Write the bytes of a var-length field into the ByteBuffer
+ *
+ * Note: only works with HeapByteBuffer
+ */
+ public void writeFieldTo(int ordinal, ByteBuffer buffer) {
+ final long offsetAndSize = getLong(ordinal);
+ final int offset = (int) (offsetAndSize >> 32);
+ final int size = (int) offsetAndSize;
+
+ buffer.putInt(size);
+ int pos = buffer.position();
+ buffer.position(pos + size);
+ Platform.copyMemory(
+ baseObject,
+ baseOffset + offset,
+ buffer.array(),
+ Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + pos,
+ size);
+ }
+
@Override
public void writeExternal(ObjectOutput out) throws IOException {
byte[] bytes = getBytes();
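
The repeated change above relies on the fact that narrowing a long to int keeps exactly the low 32 bits, which is all the removed mask computed. A minimal, self-contained illustration (the demo class name and the offset/size values are hypothetical, not part of the patch):

    public class OffsetAndSizeDemo {
      public static void main(String[] args) {
        // Pack the field offset into the high 32 bits and the size into the low 32 bits,
        // mirroring how UnsafeRow encodes a var-length field in a single long.
        long offsetAndSize = ((long) 40 << 32) | 12L;

        int offset = (int) (offsetAndSize >> 32);                   // 40
        int maskedSize = (int) (offsetAndSize & ((1L << 32) - 1));  // 12
        int castSize = (int) offsetAndSize;                         // 12, identical

        System.out.println(offset + " " + maskedSize + " " + castSize);
      }
    }
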
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
index 7f2a1cb07a..7dd932d198 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.expressions.codegen;
-import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.types.CalendarInterval;
@@ -64,29 +63,72 @@ public class UnsafeArrayWriter {
Platform.putInt(holder.buffer, getElementOffset(ordinal), relativeOffset);
}
- public void writeCompactDecimal(int ordinal, Decimal input, int precision, int scale) {
- // make sure Decimal object has the same scale as DecimalType
- if (input.changePrecision(precision, scale)) {
- Platform.putLong(holder.buffer, holder.cursor, input.toUnscaledLong());
- setOffset(ordinal);
- holder.cursor += 8;
- } else {
- setNullAt(ordinal);
+ public void write(int ordinal, boolean value) {
+ Platform.putBoolean(holder.buffer, holder.cursor, value);
+ setOffset(ordinal);
+ holder.cursor += 1;
+ }
+
+ public void write(int ordinal, byte value) {
+ Platform.putByte(holder.buffer, holder.cursor, value);
+ setOffset(ordinal);
+ holder.cursor += 1;
+ }
+
+ public void write(int ordinal, short value) {
+ Platform.putShort(holder.buffer, holder.cursor, value);
+ setOffset(ordinal);
+ holder.cursor += 2;
+ }
+
+ public void write(int ordinal, int value) {
+ Platform.putInt(holder.buffer, holder.cursor, value);
+ setOffset(ordinal);
+ holder.cursor += 4;
+ }
+
+ public void write(int ordinal, long value) {
+ Platform.putLong(holder.buffer, holder.cursor, value);
+ setOffset(ordinal);
+ holder.cursor += 8;
+ }
+
+ public void write(int ordinal, float value) {
+ if (Float.isNaN(value)) {
+ value = Float.NaN;
+ }
+ Platform.putFloat(holder.buffer, holder.cursor, value);
+ setOffset(ordinal);
+ holder.cursor += 4;
+ }
+
+ public void write(int ordinal, double value) {
+ if (Double.isNaN(value)) {
+ value = Double.NaN;
}
+ Platform.putDouble(holder.buffer, holder.cursor, value);
+ setOffset(ordinal);
+ holder.cursor += 8;
}
public void write(int ordinal, Decimal input, int precision, int scale) {
// make sure Decimal object has the same scale as DecimalType
if (input.changePrecision(precision, scale)) {
- final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
- assert bytes.length <= 16;
- holder.grow(bytes.length);
-
- // Write the bytes to the variable length portion.
- Platform.copyMemory(
- bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length);
- setOffset(ordinal);
- holder.cursor += bytes.length;
+ if (precision <= Decimal.MAX_LONG_DIGITS()) {
+ Platform.putLong(holder.buffer, holder.cursor, input.toUnscaledLong());
+ setOffset(ordinal);
+ holder.cursor += 8;
+ } else {
+ final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
+ assert bytes.length <= 16;
+ holder.grow(bytes.length);
+
+ // Write the bytes to the variable length portion.
+ Platform.copyMemory(
+ bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length);
+ setOffset(ordinal);
+ holder.cursor += bytes.length;
+ }
} else {
setNullAt(ordinal);
}
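
The consolidated write(int, Decimal, int, int) above replaces writeCompactDecimal by choosing one of two encodings from the precision: up to Decimal.MAX_LONG_DIGITS (18 digits) the unscaled value fits in a long, and beyond that the unscaled bytes go to the variable-length region (at most 16 bytes, per the assert). A rough sketch of the two encodings using plain java.math types (demo class and sample values are illustrative only):

    import java.math.BigDecimal;

    public class DecimalEncodingDemo {
      public static void main(String[] args) {
        // Precision <= 18: the unscaled value fits in a long, stored in a fixed 8-byte slot.
        BigDecimal small = new BigDecimal("12345.6789");
        long unscaled = small.unscaledValue().longValueExact();  // 123456789
        System.out.println("compact encoding: " + unscaled);

        // Higher precision: fall back to the unscaled value's two's-complement bytes,
        // written to the variable-length region (at most 16 bytes for precision <= 38).
        BigDecimal big = new BigDecimal("123456789012345678901234567890.12345678");
        byte[] bytes = big.unscaledValue().toByteArray();
        System.out.println("var-length encoding uses " + bytes.length + " bytes");
      }
    }
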
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
index e1f5a05d1d..adbe262187 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
@@ -58,6 +58,10 @@ public class UnsafeRowWriter {
}
}
+ public boolean isNullAt(int ordinal) {
+ return BitSetMethods.isSet(holder.buffer, startingOffset, ordinal);
+ }
+
public void setNullAt(int ordinal) {
BitSetMethods.set(holder.buffer, startingOffset, ordinal);
Platform.putLong(holder.buffer, getFieldOffset(ordinal), 0L);
@@ -95,41 +99,75 @@ public class UnsafeRowWriter {
}
}
- public void writeCompactDecimal(int ordinal, Decimal input, int precision, int scale) {
- // make sure Decimal object has the same scale as DecimalType
- if (input.changePrecision(precision, scale)) {
- Platform.putLong(holder.buffer, getFieldOffset(ordinal), input.toUnscaledLong());
- } else {
- setNullAt(ordinal);
- }
+ public void write(int ordinal, boolean value) {
+ Platform.putBoolean(holder.buffer, getFieldOffset(ordinal), value);
}
- public void write(int ordinal, Decimal input, int precision, int scale) {
- // grow the global buffer before writing data.
- holder.grow(16);
+ public void write(int ordinal, byte value) {
+ Platform.putByte(holder.buffer, getFieldOffset(ordinal), value);
+ }
- // zero-out the bytes
- Platform.putLong(holder.buffer, holder.cursor, 0L);
- Platform.putLong(holder.buffer, holder.cursor + 8, 0L);
+ public void write(int ordinal, short value) {
+ Platform.putShort(holder.buffer, getFieldOffset(ordinal), value);
+ }
- // Make sure Decimal object has the same scale as DecimalType.
- // Note that we may pass in null Decimal object to set null for it.
- if (input == null || !input.changePrecision(precision, scale)) {
- BitSetMethods.set(holder.buffer, startingOffset, ordinal);
- // keep the offset for future update
- setOffsetAndSize(ordinal, 0L);
- } else {
- final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
- assert bytes.length <= 16;
+ public void write(int ordinal, int value) {
+ Platform.putInt(holder.buffer, getFieldOffset(ordinal), value);
+ }
- // Write the bytes to the variable length portion.
- Platform.copyMemory(
- bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length);
- setOffsetAndSize(ordinal, bytes.length);
+ public void write(int ordinal, long value) {
+ Platform.putLong(holder.buffer, getFieldOffset(ordinal), value);
+ }
+
+ public void write(int ordinal, float value) {
+ if (Float.isNaN(value)) {
+ value = Float.NaN;
}
+ Platform.putFloat(holder.buffer, getFieldOffset(ordinal), value);
+ }
- // move the cursor forward.
- holder.cursor += 16;
+ public void write(int ordinal, double value) {
+ if (Double.isNaN(value)) {
+ value = Double.NaN;
+ }
+ Platform.putDouble(holder.buffer, getFieldOffset(ordinal), value);
+ }
+
+ public void write(int ordinal, Decimal input, int precision, int scale) {
+ if (precision <= Decimal.MAX_LONG_DIGITS()) {
+ // make sure Decimal object has the same scale as DecimalType
+ if (input.changePrecision(precision, scale)) {
+ Platform.putLong(holder.buffer, getFieldOffset(ordinal), input.toUnscaledLong());
+ } else {
+ setNullAt(ordinal);
+ }
+ } else {
+ // grow the global buffer before writing data.
+ holder.grow(16);
+
+ // zero-out the bytes
+ Platform.putLong(holder.buffer, holder.cursor, 0L);
+ Platform.putLong(holder.buffer, holder.cursor + 8, 0L);
+
+ // Make sure Decimal object has the same scale as DecimalType.
+ // Note that we may pass in null Decimal object to set null for it.
+ if (input == null || !input.changePrecision(precision, scale)) {
+ BitSetMethods.set(holder.buffer, startingOffset, ordinal);
+ // keep the offset for future update
+ setOffsetAndSize(ordinal, 0L);
+ } else {
+ final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
+ assert bytes.length <= 16;
+
+ // Write the bytes to the variable length portion.
+ Platform.copyMemory(
+ bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length);
+ setOffsetAndSize(ordinal, bytes.length);
+ }
+
+ // move the cursor forward.
+ holder.cursor += 16;
+ }
}
public void write(int ordinal, UTF8String input) {
@@ -151,7 +189,10 @@ public class UnsafeRowWriter {
}
public void write(int ordinal, byte[] input) {
- final int numBytes = input.length;
+ write(ordinal, input, 0, input.length);
+ }
+
+ public void write(int ordinal, byte[] input, int offset, int numBytes) {
final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
// grow the global buffer before writing data.
@@ -160,7 +201,8 @@ public class UnsafeRowWriter {
zeroOutPaddingBytes(numBytes);
// Write the bytes to the variable length portion.
- Platform.copyMemory(input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes);
+ Platform.copyMemory(input, Platform.BYTE_ARRAY_OFFSET + offset,
+ holder.buffer, holder.cursor, numBytes);
setOffsetAndSize(ordinal, numBytes);
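
With these additions the writer covers all fixed-length primitives plus both decimal encodings, so generated code (and the columnar reader added later in this patch) can drive it directly. A sketch of the calling pattern, using only calls that appear elsewhere in this diff (BufferHolder, UnsafeRowWriter.initialize, UnsafeRow.pointTo); the demo class and field values are illustrative:

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
    import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
    import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
    import org.apache.spark.unsafe.types.UTF8String;

    public class RowWriterSketch {
      public static void main(String[] args) {
        BufferHolder holder = new BufferHolder();
        UnsafeRowWriter writer = new UnsafeRowWriter();
        UnsafeRow row = new UnsafeRow();

        // Build one row with an int field and a var-length string field.
        holder.reset();
        writer.initialize(holder, 2);
        writer.write(0, 42);                           // fixed-length slot
        writer.write(1, UTF8String.fromString("hi"));  // appended after the fixed region

        row.pointTo(holder.buffer, 2, holder.totalSize());
        System.out.println(row.getInt(0) + " " + row.getUTF8String(1));
      }
    }
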
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index dbe92d6a83..2136f82ba4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -89,7 +89,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val setNull = dt match {
case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
// Can't call setNullAt() for DecimalType with precision larger than 18.
- s"$rowWriter.write($index, null, ${t.precision}, ${t.scale});"
+ s"$rowWriter.write($index, (Decimal) null, ${t.precision}, ${t.scale});"
case _ => s"$rowWriter.setNullAt($index);"
}
@@ -124,17 +124,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
"""
case _ if ctx.isPrimitiveType(dt) =>
- val fieldOffset = ctx.freshName("fieldOffset")
s"""
- final long $fieldOffset = $rowWriter.getFieldOffset($index);
- Platform.putLong($bufferHolder.buffer, $fieldOffset, 0L);
- ${writePrimitiveType(ctx, input.value, dt, s"$bufferHolder.buffer", fieldOffset)}
+ $rowWriter.write($index, ${input.value});
"""
- case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
- s"$rowWriter.writeCompactDecimal($index, ${input.value}, " +
- s"${t.precision}, ${t.scale});"
-
case t: DecimalType =>
s"$rowWriter.write($index, ${input.value}, ${t.precision}, ${t.scale});"
@@ -204,20 +197,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)}
"""
- case _ if ctx.isPrimitiveType(et) =>
- // Should we do word align?
- val dataSize = et.defaultSize
-
- s"""
- $arrayWriter.setOffset($index);
- ${writePrimitiveType(ctx, element, et,
- s"$bufferHolder.buffer", s"$bufferHolder.cursor")}
- $bufferHolder.cursor += $dataSize;
- """
-
- case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
- s"$arrayWriter.writeCompactDecimal($index, $element, ${t.precision}, ${t.scale});"
-
case t: DecimalType =>
s"$arrayWriter.write($index, $element, ${t.precision}, ${t.scale});"
@@ -296,38 +275,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
"""
}
- private def writePrimitiveType(
- ctx: CodeGenContext,
- input: String,
- dt: DataType,
- buffer: String,
- offset: String) = {
- assert(ctx.isPrimitiveType(dt))
-
- val putMethod = s"put${ctx.primitiveTypeName(dt)}"
-
- dt match {
- case FloatType | DoubleType =>
- val normalized = ctx.freshName("normalized")
- val boxedType = ctx.boxedType(dt)
- val handleNaN =
- s"""
- final ${ctx.javaType(dt)} $normalized;
- if ($boxedType.isNaN($input)) {
- $normalized = $boxedType.NaN;
- } else {
- $normalized = $input;
- }
- """
-
- s"""
- $handleNaN
- Platform.$putMethod($buffer, $offset, $normalized);
- """
- case _ => s"Platform.$putMethod($buffer, $offset, $input);"
- }
- }
-
def createCode(ctx: CodeGenContext, expressions: Seq[Expression]): GeneratedExpressionCode = {
val exprEvals = expressions.map(e => e.gen(ctx))
    val exprTypes = expressions.map(_.dataType)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index 72fa299aa9..68e509eb50 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -32,6 +32,13 @@ import org.apache.spark.unsafe.types.UTF8String
/**
* A help class for fast reading Int/Long/Float/Double from ByteBuffer in native order.
*
+ * Note: There is not much difference between ByteBuffer.getByte/getShort and
+ * Unsafe.getByte/getShort, so we do not have helper methods for them.
+ *
+ * Unrolling (building the columnar cache) is already slow, so putLong/putDouble would not
+ * help much; we do not have helper methods for them either.
+ *
+ *
* WARNNING: This only works with HeapByteBuffer
*/
object ByteBufferHelper {
@@ -351,7 +358,38 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 2) {
}
}
-private[sql] object STRING extends NativeColumnType(StringType, 8) {
+/**
+ * A fast path to copy var-length bytes between ByteBuffer and UnsafeRow without creating wrapper
+ * objects.
+ */
+private[sql] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType] {
+
+ // copy the bytes from ByteBuffer to UnsafeRow
+ override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+ if (row.isInstanceOf[MutableUnsafeRow]) {
+ val numBytes = buffer.getInt
+ val cursor = buffer.position()
+ buffer.position(cursor + numBytes)
+ row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, buffer.array(),
+ buffer.arrayOffset() + cursor, numBytes)
+ } else {
+ setField(row, ordinal, extract(buffer))
+ }
+ }
+
+ // copy the bytes from UnsafeRow to ByteBuffer
+ override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
+ if (row.isInstanceOf[UnsafeRow]) {
+ row.asInstanceOf[UnsafeRow].writeFieldTo(ordinal, buffer)
+ } else {
+ super.append(row, ordinal, buffer)
+ }
+ }
+}
+
+private[sql] object STRING
+ extends NativeColumnType(StringType, 8) with DirectCopyColumnType[UTF8String] {
+
override def actualSize(row: InternalRow, ordinal: Int): Int = {
row.getUTF8String(ordinal).numBytes() + 4
}
@@ -363,16 +401,17 @@ private[sql] object STRING extends NativeColumnType(StringType, 8) {
override def extract(buffer: ByteBuffer): UTF8String = {
val length = buffer.getInt()
- assert(buffer.hasArray)
- val base = buffer.array()
- val offset = buffer.arrayOffset()
val cursor = buffer.position()
buffer.position(cursor + length)
- UTF8String.fromBytes(base, offset + cursor, length)
+ UTF8String.fromBytes(buffer.array(), buffer.arrayOffset() + cursor, length)
}
override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = {
- row.update(ordinal, value.clone())
+ if (row.isInstanceOf[MutableUnsafeRow]) {
+ row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, value)
+ } else {
+ row.update(ordinal, value.clone())
+ }
}
override def getField(row: InternalRow, ordinal: Int): UTF8String = {
@@ -393,10 +432,28 @@ private[sql] case class COMPACT_DECIMAL(precision: Int, scale: Int)
Decimal(ByteBufferHelper.getLong(buffer), precision, scale)
}
+ override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+ if (row.isInstanceOf[MutableUnsafeRow]) {
+ // copy it as Long
+ row.setLong(ordinal, ByteBufferHelper.getLong(buffer))
+ } else {
+ setField(row, ordinal, extract(buffer))
+ }
+ }
+
override def append(v: Decimal, buffer: ByteBuffer): Unit = {
buffer.putLong(v.toUnscaledLong)
}
+ override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
+ if (row.isInstanceOf[UnsafeRow]) {
+ // copy it as Long
+ buffer.putLong(row.getLong(ordinal))
+ } else {
+ append(getField(row, ordinal), buffer)
+ }
+ }
+
override def getField(row: InternalRow, ordinal: Int): Decimal = {
row.getDecimal(ordinal, precision, scale)
}
@@ -417,7 +474,7 @@ private[sql] object COMPACT_DECIMAL {
}
private[sql] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: Int)
- extends ColumnType[JvmType] {
+ extends ColumnType[JvmType] with DirectCopyColumnType[JvmType] {
def serialize(value: JvmType): Array[Byte]
def deserialize(bytes: Array[Byte]): JvmType
@@ -488,7 +545,8 @@ private[sql] object LARGE_DECIMAL {
}
}
-private[sql] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRow] {
+private[sql] case class STRUCT(dataType: StructType)
+ extends ColumnType[UnsafeRow] with DirectCopyColumnType[UnsafeRow] {
private val numOfFields: Int = dataType.fields.size
@@ -528,7 +586,8 @@ private[sql] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRo
override def clone(v: UnsafeRow): UnsafeRow = v.copy()
}
-private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArrayData] {
+private[sql] case class ARRAY(dataType: ArrayType)
+ extends ColumnType[UnsafeArrayData] with DirectCopyColumnType[UnsafeArrayData] {
override def defaultSize: Int = 16
@@ -566,7 +625,8 @@ private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArra
override def clone(v: UnsafeArrayData): UnsafeArrayData = v.copy()
}
-private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData] {
+private[sql] case class MAP(dataType: MapType)
+ extends ColumnType[UnsafeMapData] with DirectCopyColumnType[UnsafeMapData] {
override def defaultSize: Int = 32
@@ -590,7 +650,6 @@ private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData]
override def extract(buffer: ByteBuffer): UnsafeMapData = {
val numBytes = buffer.getInt
- assert(buffer.hasArray)
val cursor = buffer.position()
buffer.position(cursor + numBytes)
val map = new UnsafeMapData
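
DirectCopyColumnType and UnsafeRow.writeFieldTo agree on a simple layout for each var-length value in the cache buffer: an int length prefix followed by the raw payload bytes, copied in and out of the heap-backed array without creating intermediate objects (hence the HeapByteBuffer restriction). A standalone sketch of that layout using plain ByteBuffer calls rather than the patch's Platform.copyMemory (demo class and payload are illustrative):

    import java.nio.ByteBuffer;
    import java.nio.ByteOrder;
    import java.nio.charset.StandardCharsets;

    public class LengthPrefixedLayoutDemo {
      public static void main(String[] args) {
        byte[] payload = "hello".getBytes(StandardCharsets.UTF_8);

        // Writing side: [int length][payload bytes], as writeFieldTo produces.
        ByteBuffer buffer = ByteBuffer.allocate(64).order(ByteOrder.nativeOrder());
        buffer.putInt(payload.length);
        buffer.put(payload);
        buffer.flip();

        // Reading side: consume the length prefix, then copy straight out of the
        // backing array, as DirectCopyColumnType.extract does.
        int numBytes = buffer.getInt();
        int cursor = buffer.position();
        buffer.position(cursor + numBytes);
        byte[] out = new byte[numBytes];
        System.arraycopy(buffer.array(), buffer.arrayOffset() + cursor, out, 0, numBytes);
        System.out.println(new String(out, StandardCharsets.UTF_8));
      }
    }
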
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala
index e04bcda580..d0f5bfa1cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala
@@ -20,18 +20,44 @@ package org.apache.spark.sql.columnar
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodeGenerator}
+import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeRowWriter, CodeFormatter, CodeGenerator}
import org.apache.spark.sql.types._
/**
- * An Iterator to walk throught the InternalRows from a CachedBatch
+ * An Iterator to walk through the InternalRows from a CachedBatch
*/
abstract class ColumnarIterator extends Iterator[InternalRow] {
- def initialize(input: Iterator[CachedBatch], mutableRow: MutableRow, columnTypes: Array[DataType],
+ def initialize(input: Iterator[CachedBatch], columnTypes: Array[DataType],
columnIndexes: Array[Int]): Unit
}
/**
+ * A helper class to update the fields of an UnsafeRow, used by ColumnAccessor
+ *
+ * WARNING: These setters MUST be called in increasing order of the ordinals.
+ */
+class MutableUnsafeRow(val writer: UnsafeRowWriter) extends GenericMutableRow(null) {
+
+ override def isNullAt(i: Int): Boolean = writer.isNullAt(i)
+ override def setNullAt(i: Int): Unit = writer.setNullAt(i)
+
+ override def setBoolean(i: Int, v: Boolean): Unit = writer.write(i, v)
+ override def setByte(i: Int, v: Byte): Unit = writer.write(i, v)
+ override def setShort(i: Int, v: Short): Unit = writer.write(i, v)
+ override def setInt(i: Int, v: Int): Unit = writer.write(i, v)
+ override def setLong(i: Int, v: Long): Unit = writer.write(i, v)
+ override def setFloat(i: Int, v: Float): Unit = writer.write(i, v)
+ override def setDouble(i: Int, v: Double): Unit = writer.write(i, v)
+
+ // the writer will be used directly to avoid creating wrapper objects
+ override def setDecimal(i: Int, v: Decimal, precision: Int): Unit =
+ throw new UnsupportedOperationException
+ override def update(i: Int, v: Any): Unit = throw new UnsupportedOperationException
+
+ // all other methods inherited from GenericMutableRow are not needed
+}
+
+/**
* Generates bytecode for an [[ColumnarIterator]] for columnar cache.
*/
object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarIterator] with Logging {
@@ -41,6 +67,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
protected def create(columnTypes: Seq[DataType]): ColumnarIterator = {
val ctx = newCodeGenContext()
+ val numFields = columnTypes.size
val (initializeAccessors, extractors) = columnTypes.zipWithIndex.map { case (dt, index) =>
val accessorName = ctx.freshName("accessor")
val accessorCls = dt match {
@@ -74,13 +101,27 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
}
val extract = s"$accessorName.extractTo(mutableRow, $index);"
-
- (createCode, extract)
+ val patch = dt match {
+ case DecimalType.Fixed(p, s) if p > Decimal.MAX_LONG_DIGITS =>
+ // For a large Decimal, reserve 16 bytes for future in-place updates even if it is null now.
+ s"""
+ if (mutableRow.isNullAt($index)) {
+ rowWriter.write($index, (Decimal) null, $p, $s);
+ }
+ """
+ case other => ""
+ }
+ (createCode, extract + patch)
}.unzip
val code = s"""
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
+ import scala.collection.Iterator;
+ import org.apache.spark.sql.types.DataType;
+ import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
+ import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
+ import org.apache.spark.sql.columnar.MutableUnsafeRow;
public SpecificColumnarIterator generate($exprType[] expr) {
return new SpecificColumnarIterator();
@@ -90,13 +131,17 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
private ByteOrder nativeOrder = null;
private byte[][] buffers = null;
+ private UnsafeRow unsafeRow = new UnsafeRow();
+ private BufferHolder bufferHolder = new BufferHolder();
+ private UnsafeRowWriter rowWriter = new UnsafeRowWriter();
+ private MutableUnsafeRow mutableRow = null;
private int currentRow = 0;
private int numRowsInBatch = 0;
private scala.collection.Iterator input = null;
- private MutableRow mutableRow = null;
- private ${classOf[DataType].getName}[] columnTypes = null;
+ private DataType[] columnTypes = null;
private int[] columnIndexes = null;
${declareMutableStates(ctx)}
@@ -104,12 +149,12 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
public SpecificColumnarIterator() {
this.nativeOrder = ByteOrder.nativeOrder();
this.buffers = new byte[${columnTypes.length}][];
+ this.mutableRow = new MutableUnsafeRow(rowWriter);
${initMutableStates(ctx)}
}
- public void initialize(scala.collection.Iterator input, MutableRow mutableRow,
- ${classOf[DataType].getName}[] columnTypes, int[] columnIndexes) {
+ public void initialize(Iterator input, DataType[] columnTypes, int[] columnIndexes) {
this.input = input;
- this.mutableRow = mutableRow;
this.columnTypes = columnTypes;
@@ -136,9 +181,12 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
}
public InternalRow next() {
- ${extractors.mkString("\n")}
currentRow += 1;
- return mutableRow;
+ bufferHolder.reset();
+ rowWriter.initialize(bufferHolder, $numFields);
+ ${extractors.mkString("\n")}
+ unsafeRow.pointTo(bufferHolder.buffer, $numFields, bufferHolder.totalSize());
+ return unsafeRow;
}
}"""
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index 9f76a61a15..b4607b12fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -209,6 +209,8 @@ private[sql] case class InMemoryColumnarTableScan(
override def output: Seq[Attribute] = attributes
+ override def outputsUnsafeRows: Boolean = true
+
private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a)
// Returned filter predicate should return false iff it is impossible for the input expression
@@ -317,14 +319,12 @@ private[sql] case class InMemoryColumnarTableScan(
cachedBatchIterator
}
- val nextRow = new SpecificMutableRow(requestedColumnDataTypes)
val columnTypes = requestedColumnDataTypes.map {
case udt: UserDefinedType[_] => udt.sqlType
case other => other
}.toArray
val columnarIterator = GenerateColumnAccessor.generate(columnTypes)
- columnarIterator.initialize(cachedBatchesToScan, nextRow, columnTypes,
- requestedColumnIndices.toArray)
+ columnarIterator.initialize(cachedBatchesToScan, columnTypes, requestedColumnIndices.toArray)
if (enableAccumulators && columnarIterator.hasNext) {
readPartitions += 1
}