diff options
author | Davies Liu <davies@databricks.com> | 2015-10-21 19:20:31 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-10-21 19:20:31 -0700 |
commit | 1d9733271595596683a6d956a7433fa601df1cc1 (patch) | |
tree | dfe891e5f6bd28726f99dcc092d43e237ce154f3 /sql/core | |
parent | 40a10d7675578f8370d07e23810d9fc5d58e0550 (diff) | |
download | spark-1d9733271595596683a6d956a7433fa601df1cc1.tar.gz spark-1d9733271595596683a6d956a7433fa601df1cc1.tar.bz2 spark-1d9733271595596683a6d956a7433fa601df1cc1.zip |
[SPARK-11243][SQL] output UnsafeRow from columnar cache
This PR changes InMemoryTableScan to output UnsafeRow, and optimizes the unrolling and scanning by copying the bytes for var-length types between UnsafeRow and ByteBuffer directly, without creating the wrapper objects. When scanning the decimals in the TPC-DS store_sales table, it's 80% faster (copying them as longs without creating Decimal objects).
Author: Davies Liu <davies@databricks.com>
Closes #9203 from davies/unsafe_cache.
Diffstat (limited to 'sql/core')
3 files changed, 131 insertions, 24 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 72fa299aa9..68e509eb50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -32,6 +32,13 @@ import org.apache.spark.unsafe.types.UTF8String /** * A help class for fast reading Int/Long/Float/Double from ByteBuffer in native order. * + * Note: There is not much difference between ByteBuffer.getByte/getShort and + * Unsafe.getByte/getShort, so we do not have helper methods for them. + * + * The unrolling (building columnar cache) is already slow, putLong/putDouble will not help much, + * so we do not have helper methods for them. + * + * * WARNNING: This only works with HeapByteBuffer */ object ByteBufferHelper { @@ -351,7 +358,38 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 2) { } } -private[sql] object STRING extends NativeColumnType(StringType, 8) { +/** + * A fast path to copy var-length bytes between ByteBuffer and UnsafeRow without creating wrapper + * objects. 
+ */ +private[sql] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType] { + + // copy the bytes from ByteBuffer to UnsafeRow + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + if (row.isInstanceOf[MutableUnsafeRow]) { + val numBytes = buffer.getInt + val cursor = buffer.position() + buffer.position(cursor + numBytes) + row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, buffer.array(), + buffer.arrayOffset() + cursor, numBytes) + } else { + setField(row, ordinal, extract(buffer)) + } + } + + // copy the bytes from UnsafeRow to ByteBuffer + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { + if (row.isInstanceOf[UnsafeRow]) { + row.asInstanceOf[UnsafeRow].writeFieldTo(ordinal, buffer) + } else { + super.append(row, ordinal, buffer) + } + } +} + +private[sql] object STRING + extends NativeColumnType(StringType, 8) with DirectCopyColumnType[UTF8String] { + override def actualSize(row: InternalRow, ordinal: Int): Int = { row.getUTF8String(ordinal).numBytes() + 4 } @@ -363,16 +401,17 @@ private[sql] object STRING extends NativeColumnType(StringType, 8) { override def extract(buffer: ByteBuffer): UTF8String = { val length = buffer.getInt() - assert(buffer.hasArray) - val base = buffer.array() - val offset = buffer.arrayOffset() val cursor = buffer.position() buffer.position(cursor + length) - UTF8String.fromBytes(base, offset + cursor, length) + UTF8String.fromBytes(buffer.array(), buffer.arrayOffset() + cursor, length) } override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = { - row.update(ordinal, value.clone()) + if (row.isInstanceOf[MutableUnsafeRow]) { + row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, value) + } else { + row.update(ordinal, value.clone()) + } } override def getField(row: InternalRow, ordinal: Int): UTF8String = { @@ -393,10 +432,28 @@ private[sql] case class COMPACT_DECIMAL(precision: Int, scale: Int) 
Decimal(ByteBufferHelper.getLong(buffer), precision, scale) } + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + if (row.isInstanceOf[MutableUnsafeRow]) { + // copy it as Long + row.setLong(ordinal, ByteBufferHelper.getLong(buffer)) + } else { + setField(row, ordinal, extract(buffer)) + } + } + override def append(v: Decimal, buffer: ByteBuffer): Unit = { buffer.putLong(v.toUnscaledLong) } + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { + if (row.isInstanceOf[UnsafeRow]) { + // copy it as Long + buffer.putLong(row.getLong(ordinal)) + } else { + append(getField(row, ordinal), buffer) + } + } + override def getField(row: InternalRow, ordinal: Int): Decimal = { row.getDecimal(ordinal, precision, scale) } @@ -417,7 +474,7 @@ private[sql] object COMPACT_DECIMAL { } private[sql] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: Int) - extends ColumnType[JvmType] { + extends ColumnType[JvmType] with DirectCopyColumnType[JvmType] { def serialize(value: JvmType): Array[Byte] def deserialize(bytes: Array[Byte]): JvmType @@ -488,7 +545,8 @@ private[sql] object LARGE_DECIMAL { } } -private[sql] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRow] { +private[sql] case class STRUCT(dataType: StructType) + extends ColumnType[UnsafeRow] with DirectCopyColumnType[UnsafeRow] { private val numOfFields: Int = dataType.fields.size @@ -528,7 +586,8 @@ private[sql] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRo override def clone(v: UnsafeRow): UnsafeRow = v.copy() } -private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArrayData] { +private[sql] case class ARRAY(dataType: ArrayType) + extends ColumnType[UnsafeArrayData] with DirectCopyColumnType[UnsafeArrayData] { override def defaultSize: Int = 16 @@ -566,7 +625,8 @@ private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArra override def clone(v: UnsafeArrayData): 
UnsafeArrayData = v.copy() } -private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData] { +private[sql] case class MAP(dataType: MapType) + extends ColumnType[UnsafeMapData] with DirectCopyColumnType[UnsafeMapData] { override def defaultSize: Int = 32 @@ -590,7 +650,6 @@ private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData] override def extract(buffer: ByteBuffer): UnsafeMapData = { val numBytes = buffer.getInt - assert(buffer.hasArray) val cursor = buffer.position() buffer.position(cursor + numBytes) val map = new UnsafeMapData diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala index e04bcda580..d0f5bfa1cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala @@ -20,18 +20,44 @@ package org.apache.spark.sql.columnar import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodeGenerator} +import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeRowWriter, CodeFormatter, CodeGenerator} import org.apache.spark.sql.types._ /** - * An Iterator to walk throught the InternalRows from a CachedBatch + * An Iterator to walk through the InternalRows from a CachedBatch */ abstract class ColumnarIterator extends Iterator[InternalRow] { - def initialize(input: Iterator[CachedBatch], mutableRow: MutableRow, columnTypes: Array[DataType], + def initialize(input: Iterator[CachedBatch], columnTypes: Array[DataType], columnIndexes: Array[Int]): Unit } /** + * An helper class to update the fields of UnsafeRow, used by ColumnAccessor + * + * WARNNING: These setter MUST be called in increasing order of ordinals. 
+ */ +class MutableUnsafeRow(val writer: UnsafeRowWriter) extends GenericMutableRow(null) { + + override def isNullAt(i: Int): Boolean = writer.isNullAt(i) + override def setNullAt(i: Int): Unit = writer.setNullAt(i) + + override def setBoolean(i: Int, v: Boolean): Unit = writer.write(i, v) + override def setByte(i: Int, v: Byte): Unit = writer.write(i, v) + override def setShort(i: Int, v: Short): Unit = writer.write(i, v) + override def setInt(i: Int, v: Int): Unit = writer.write(i, v) + override def setLong(i: Int, v: Long): Unit = writer.write(i, v) + override def setFloat(i: Int, v: Float): Unit = writer.write(i, v) + override def setDouble(i: Int, v: Double): Unit = writer.write(i, v) + + // the writer will be used directly to avoid creating wrapper objects + override def setDecimal(i: Int, v: Decimal, precision: Int): Unit = + throw new UnsupportedOperationException + override def update(i: Int, v: Any): Unit = throw new UnsupportedOperationException + + // all other methods inherited from GenericMutableRow are not need +} + +/** * Generates bytecode for an [[ColumnarIterator]] for columnar cache. */ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarIterator] with Logging { @@ -41,6 +67,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera protected def create(columnTypes: Seq[DataType]): ColumnarIterator = { val ctx = newCodeGenContext() + val numFields = columnTypes.size val (initializeAccessors, extractors) = columnTypes.zipWithIndex.map { case (dt, index) => val accessorName = ctx.freshName("accessor") val accessorCls = dt match { @@ -74,13 +101,27 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera } val extract = s"$accessorName.extractTo(mutableRow, $index);" - - (createCode, extract) + val patch = dt match { + case DecimalType.Fixed(p, s) if p > Decimal.MAX_LONG_DIGITS => + // For large Decimal, it should have 16 bytes for future update even it's null now. 
+ s""" + if (mutableRow.isNullAt($index)) { + rowWriter.write($index, (Decimal) null, $p, $s); + } + """ + case other => "" + } + (createCode, extract + patch) }.unzip val code = s""" import java.nio.ByteBuffer; import java.nio.ByteOrder; + import scala.collection.Iterator; + import org.apache.spark.sql.types.DataType; + import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; + import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; + import org.apache.spark.sql.columnar.MutableUnsafeRow; public SpecificColumnarIterator generate($exprType[] expr) { return new SpecificColumnarIterator(); @@ -90,13 +131,17 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera private ByteOrder nativeOrder = null; private byte[][] buffers = null; + private UnsafeRow unsafeRow = new UnsafeRow(); + private BufferHolder bufferHolder = new BufferHolder(); + private UnsafeRowWriter rowWriter = new UnsafeRowWriter(); + private MutableUnsafeRow mutableRow = null; private int currentRow = 0; private int numRowsInBatch = 0; private scala.collection.Iterator input = null; private MutableRow mutableRow = null; - private ${classOf[DataType].getName}[] columnTypes = null; + private DataType[] columnTypes = null; private int[] columnIndexes = null; ${declareMutableStates(ctx)} @@ -104,12 +149,12 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera public SpecificColumnarIterator() { this.nativeOrder = ByteOrder.nativeOrder(); this.buffers = new byte[${columnTypes.length}][]; + this.mutableRow = new MutableUnsafeRow(rowWriter); ${initMutableStates(ctx)} } - public void initialize(scala.collection.Iterator input, MutableRow mutableRow, - ${classOf[DataType].getName}[] columnTypes, int[] columnIndexes) { + public void initialize(Iterator input, DataType[] columnTypes, int[] columnIndexes) { this.input = input; this.mutableRow = mutableRow; this.columnTypes = columnTypes; @@ -136,9 +181,12 @@ object 
GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera } public InternalRow next() { - ${extractors.mkString("\n")} currentRow += 1; - return mutableRow; + bufferHolder.reset(); + rowWriter.initialize(bufferHolder, $numFields); + ${extractors.mkString("\n")} + unsafeRow.pointTo(bufferHolder.buffer, $numFields, bufferHolder.totalSize()); + return unsafeRow; } }""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 9f76a61a15..b4607b12fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -209,6 +209,8 @@ private[sql] case class InMemoryColumnarTableScan( override def output: Seq[Attribute] = attributes + override def outputsUnsafeRows: Boolean = true + private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a) // Returned filter predicate should return false iff it is impossible for the input expression @@ -317,14 +319,12 @@ private[sql] case class InMemoryColumnarTableScan( cachedBatchIterator } - val nextRow = new SpecificMutableRow(requestedColumnDataTypes) val columnTypes = requestedColumnDataTypes.map { case udt: UserDefinedType[_] => udt.sqlType case other => other }.toArray val columnarIterator = GenerateColumnAccessor.generate(columnTypes) - columnarIterator.initialize(cachedBatchesToScan, nextRow, columnTypes, - requestedColumnIndices.toArray) + columnarIterator.initialize(cachedBatchesToScan, columnTypes, requestedColumnIndices.toArray) if (enableAccumulators && columnarIterator.hasNext) { readPartitions += 1 } |