author    Davies Liu <davies@databricks.com>  2015-10-21 19:20:31 -0700
committer Reynold Xin <rxin@databricks.com>  2015-10-21 19:20:31 -0700
commit    1d9733271595596683a6d956a7433fa601df1cc1 (patch)
tree      dfe891e5f6bd28726f99dcc092d43e237ce154f3 /sql/core
parent    40a10d7675578f8370d07e23810d9fc5d58e0550 (diff)
[SPARK-11243][SQL] output UnsafeRow from columnar cache
This PR changes InMemoryTableScan to output UnsafeRow, and optimizes unrolling and scanning by copying the bytes of var-length types between UnsafeRow and ByteBuffer directly, without creating wrapper objects. When scanning the decimals in the TPC-DS store_sales table, this is 80% faster (each value is copied as a long without creating a Decimal object).

Author: Davies Liu <davies@databricks.com>

Closes #9203 from davies/unsafe_cache.
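As an illustration of the decimal fast path described in the message above, here is a minimal standalone Scala sketch contrasting the two code paths. It is not the patch itself: Spark's Decimal and MutableRow are stood in by a plain JDK type and a function so the sketch compiles without Spark on the classpath.

import java.nio.ByteBuffer

object CompactDecimalSketch {
  // Slow path: wrap the unscaled long in a decimal object, only for the
  // destination row to unwrap it again -- one allocation per value.
  def extractSlow(buffer: ByteBuffer, scale: Int): java.math.BigDecimal =
    java.math.BigDecimal.valueOf(buffer.getLong(), scale)

  // Fast path: a compact decimal (precision <= 18) is stored as its
  // unscaled long, and UnsafeRow stores it the same way, so the eight
  // bytes can be copied directly with no intermediate object.
  def extractFast(buffer: ByteBuffer, setLong: Long => Unit): Unit =
    setLong(buffer.getLong())
}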
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala               | 81
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala   | 68
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala | 6
3 files changed, 131 insertions(+), 24 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index 72fa299aa9..68e509eb50 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -32,6 +32,12 @@ import org.apache.spark.unsafe.types.UTF8String
/**
* A helper class for fast reading of Int/Long/Float/Double from a ByteBuffer in native order.
*
+ * Note: There is not much difference between ByteBuffer.getByte/getShort and
+ * Unsafe.getByte/getShort, so we do not have helper methods for them.
+ *
+ * Unrolling (building the columnar cache) is already slow, so putLong/putDouble would not
+ * help much; we do not have helper methods for them either.
+ *
+ * WARNING: This only works with HeapByteBuffer
*/
object ByteBufferHelper {
@@ -351,7 +358,38 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 2) {
}
}
-private[sql] object STRING extends NativeColumnType(StringType, 8) {
+/**
+ * A fast path to copy var-length bytes between ByteBuffer and UnsafeRow without creating wrapper
+ * objects.
+ */
+private[sql] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType] {
+
+ // copy the bytes from ByteBuffer to UnsafeRow
+ override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+ if (row.isInstanceOf[MutableUnsafeRow]) {
+ val numBytes = buffer.getInt
+ val cursor = buffer.position()
+ buffer.position(cursor + numBytes)
+ row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, buffer.array(),
+ buffer.arrayOffset() + cursor, numBytes)
+ } else {
+ setField(row, ordinal, extract(buffer))
+ }
+ }
+
+ // copy the bytes from UnsafeRow to ByteBuffer
+ override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
+ if (row.isInstanceOf[UnsafeRow]) {
+ row.asInstanceOf[UnsafeRow].writeFieldTo(ordinal, buffer)
+ } else {
+ super.append(row, ordinal, buffer)
+ }
+ }
+}
+
+private[sql] object STRING
+ extends NativeColumnType(StringType, 8) with DirectCopyColumnType[UTF8String] {
+
override def actualSize(row: InternalRow, ordinal: Int): Int = {
row.getUTF8String(ordinal).numBytes() + 4
}
@@ -363,16 +401,17 @@ private[sql] object STRING extends NativeColumnType(StringType, 8) {
override def extract(buffer: ByteBuffer): UTF8String = {
val length = buffer.getInt()
- assert(buffer.hasArray)
- val base = buffer.array()
- val offset = buffer.arrayOffset()
val cursor = buffer.position()
buffer.position(cursor + length)
- UTF8String.fromBytes(base, offset + cursor, length)
+ UTF8String.fromBytes(buffer.array(), buffer.arrayOffset() + cursor, length)
}
override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = {
- row.update(ordinal, value.clone())
+ if (row.isInstanceOf[MutableUnsafeRow]) {
+ row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, value)
+ } else {
+ row.update(ordinal, value.clone())
+ }
}
override def getField(row: InternalRow, ordinal: Int): UTF8String = {
@@ -393,10 +432,28 @@ private[sql] case class COMPACT_DECIMAL(precision: Int, scale: Int)
Decimal(ByteBufferHelper.getLong(buffer), precision, scale)
}
+ override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+ if (row.isInstanceOf[MutableUnsafeRow]) {
+ // copy it as Long
+ row.setLong(ordinal, ByteBufferHelper.getLong(buffer))
+ } else {
+ setField(row, ordinal, extract(buffer))
+ }
+ }
+
override def append(v: Decimal, buffer: ByteBuffer): Unit = {
buffer.putLong(v.toUnscaledLong)
}
+ override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
+ if (row.isInstanceOf[UnsafeRow]) {
+ // copy it as Long
+ buffer.putLong(row.getLong(ordinal))
+ } else {
+ append(getField(row, ordinal), buffer)
+ }
+ }
+
override def getField(row: InternalRow, ordinal: Int): Decimal = {
row.getDecimal(ordinal, precision, scale)
}
@@ -417,7 +474,7 @@ private[sql] object COMPACT_DECIMAL {
}
private[sql] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: Int)
- extends ColumnType[JvmType] {
+ extends ColumnType[JvmType] with DirectCopyColumnType[JvmType] {
def serialize(value: JvmType): Array[Byte]
def deserialize(bytes: Array[Byte]): JvmType
@@ -488,7 +545,8 @@ private[sql] object LARGE_DECIMAL {
}
}
-private[sql] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRow] {
+private[sql] case class STRUCT(dataType: StructType)
+ extends ColumnType[UnsafeRow] with DirectCopyColumnType[UnsafeRow] {
private val numOfFields: Int = dataType.fields.size
@@ -528,7 +586,8 @@ private[sql] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRo
override def clone(v: UnsafeRow): UnsafeRow = v.copy()
}
-private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArrayData] {
+private[sql] case class ARRAY(dataType: ArrayType)
+ extends ColumnType[UnsafeArrayData] with DirectCopyColumnType[UnsafeArrayData] {
override def defaultSize: Int = 16
@@ -566,7 +625,8 @@ private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArra
override def clone(v: UnsafeArrayData): UnsafeArrayData = v.copy()
}
-private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData] {
+private[sql] case class MAP(dataType: MapType)
+ extends ColumnType[UnsafeMapData] with DirectCopyColumnType[UnsafeMapData] {
override def defaultSize: Int = 32
@@ -590,7 +650,6 @@ private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData]
override def extract(buffer: ByteBuffer): UnsafeMapData = {
val numBytes = buffer.getInt
- assert(buffer.hasArray)
val cursor = buffer.position()
buffer.position(cursor + numBytes)
val map = new UnsafeMapData
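Stepping back from the diff, the fast path that DirectCopyColumnType adds can be summarized in a short self-contained Scala sketch; the write parameter stands in for UnsafeRowWriter.write(ordinal, base, offset, numBytes) from the patch, so nothing here is Spark-specific:

import java.nio.ByteBuffer

object DirectCopySketch {
  // Reads one length-prefixed value and hands the raw bytes to the row
  // writer straight out of the heap buffer's backing array, without
  // materializing a UTF8String/UnsafeRow/UnsafeArrayData wrapper first.
  def extractTo(buffer: ByteBuffer, write: (Array[Byte], Int, Int) => Unit): Unit = {
    val numBytes = buffer.getInt        // length prefix written by append()
    val cursor = buffer.position()
    buffer.position(cursor + numBytes)  // advance past the value
    // heap buffers only -- hence the WARNING on ByteBufferHelper
    write(buffer.array(), buffer.arrayOffset() + cursor, numBytes)
  }
}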
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala
index e04bcda580..d0f5bfa1cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala
@@ -20,18 +20,44 @@ package org.apache.spark.sql.columnar
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodeGenerator}
+import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeRowWriter, CodeFormatter, CodeGenerator}
import org.apache.spark.sql.types._
/**
- * An Iterator to walk throught the InternalRows from a CachedBatch
+ * An Iterator to walk through the InternalRows from a CachedBatch
*/
abstract class ColumnarIterator extends Iterator[InternalRow] {
- def initialize(input: Iterator[CachedBatch], mutableRow: MutableRow, columnTypes: Array[DataType],
+ def initialize(input: Iterator[CachedBatch], columnTypes: Array[DataType],
columnIndexes: Array[Int]): Unit
}
/**
+ * A helper class to update the fields of an UnsafeRow, used by ColumnAccessor
+ *
+ * WARNING: These setters MUST be called in increasing order of ordinals.
+ */
+class MutableUnsafeRow(val writer: UnsafeRowWriter) extends GenericMutableRow(null) {
+
+ override def isNullAt(i: Int): Boolean = writer.isNullAt(i)
+ override def setNullAt(i: Int): Unit = writer.setNullAt(i)
+
+ override def setBoolean(i: Int, v: Boolean): Unit = writer.write(i, v)
+ override def setByte(i: Int, v: Byte): Unit = writer.write(i, v)
+ override def setShort(i: Int, v: Short): Unit = writer.write(i, v)
+ override def setInt(i: Int, v: Int): Unit = writer.write(i, v)
+ override def setLong(i: Int, v: Long): Unit = writer.write(i, v)
+ override def setFloat(i: Int, v: Float): Unit = writer.write(i, v)
+ override def setDouble(i: Int, v: Double): Unit = writer.write(i, v)
+
+ // the writer will be used directly to avoid creating wrapper objects
+ override def setDecimal(i: Int, v: Decimal, precision: Int): Unit =
+ throw new UnsupportedOperationException
+ override def update(i: Int, v: Any): Unit = throw new UnsupportedOperationException
+
+ // all other methods inherited from GenericMutableRow are not needed
+}
+
+/**
* Generates bytecode for a [[ColumnarIterator]] for the columnar cache.
*/
object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarIterator] with Logging {
@@ -41,6 +67,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
protected def create(columnTypes: Seq[DataType]): ColumnarIterator = {
val ctx = newCodeGenContext()
+ val numFields = columnTypes.size
val (initializeAccessors, extractors) = columnTypes.zipWithIndex.map { case (dt, index) =>
val accessorName = ctx.freshName("accessor")
val accessorCls = dt match {
@@ -74,13 +101,27 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
}
val extract = s"$accessorName.extractTo(mutableRow, $index);"
-
- (createCode, extract)
+ val patch = dt match {
+ case DecimalType.Fixed(p, s) if p > Decimal.MAX_LONG_DIGITS =>
+ // A large Decimal must have its 16 bytes reserved for future updates, even if it is null now.
+ s"""
+ if (mutableRow.isNullAt($index)) {
+ rowWriter.write($index, (Decimal) null, $p, $s);
+ }
+ """
+ case other => ""
+ }
+ (createCode, extract + patch)
}.unzip
val code = s"""
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
+ import scala.collection.Iterator;
+ import org.apache.spark.sql.types.DataType;
+ import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
+ import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
+ import org.apache.spark.sql.columnar.MutableUnsafeRow;
public SpecificColumnarIterator generate($exprType[] expr) {
return new SpecificColumnarIterator();
@@ -90,13 +131,17 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
private ByteOrder nativeOrder = null;
private byte[][] buffers = null;
+ private UnsafeRow unsafeRow = new UnsafeRow();
+ private BufferHolder bufferHolder = new BufferHolder();
+ private UnsafeRowWriter rowWriter = new UnsafeRowWriter();
+ private MutableUnsafeRow mutableRow = null;
private int currentRow = 0;
private int numRowsInBatch = 0;
private scala.collection.Iterator input = null;
- private MutableRow mutableRow = null;
- private ${classOf[DataType].getName}[] columnTypes = null;
+ private DataType[] columnTypes = null;
private int[] columnIndexes = null;
${declareMutableStates(ctx)}
@@ -104,12 +149,12 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
public SpecificColumnarIterator() {
this.nativeOrder = ByteOrder.nativeOrder();
this.buffers = new byte[${columnTypes.length}][];
+ this.mutableRow = new MutableUnsafeRow(rowWriter);
${initMutableStates(ctx)}
}
- public void initialize(scala.collection.Iterator input, MutableRow mutableRow,
- ${classOf[DataType].getName}[] columnTypes, int[] columnIndexes) {
+ public void initialize(Iterator input, DataType[] columnTypes, int[] columnIndexes) {
this.input = input;
- this.mutableRow = mutableRow;
this.columnTypes = columnTypes;
@@ -136,9 +181,12 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
}
public InternalRow next() {
- ${extractors.mkString("\n")}
currentRow += 1;
- return mutableRow;
+ bufferHolder.reset();
+ rowWriter.initialize(bufferHolder, $numFields);
+ ${extractors.mkString("\n")}
+ unsafeRow.pointTo(bufferHolder.buffer, $numFields, bufferHolder.totalSize());
+ return unsafeRow;
}
}"""
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index 9f76a61a15..b4607b12fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -209,6 +209,8 @@ private[sql] case class InMemoryColumnarTableScan(
override def output: Seq[Attribute] = attributes
+ override def outputsUnsafeRows: Boolean = true
+
private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a)
// Returned filter predicate should return false iff it is impossible for the input expression
@@ -317,14 +319,12 @@ private[sql] case class InMemoryColumnarTableScan(
cachedBatchIterator
}
- val nextRow = new SpecificMutableRow(requestedColumnDataTypes)
val columnTypes = requestedColumnDataTypes.map {
case udt: UserDefinedType[_] => udt.sqlType
case other => other
}.toArray
val columnarIterator = GenerateColumnAccessor.generate(columnTypes)
- columnarIterator.initialize(cachedBatchesToScan, nextRow, columnTypes,
- requestedColumnIndices.toArray)
+ columnarIterator.initialize(cachedBatchesToScan, columnTypes, requestedColumnIndices.toArray)
if (enableAccumulators && columnarIterator.hasNext) {
readPartitions += 1
}