diff options
author | Davies Liu <davies@databricks.com> | 2015-10-20 14:01:53 -0700 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2015-10-20 14:01:53 -0700 |
commit | 06e6b765d0c747b773d7f3be28ddb0543c955a1f (patch) | |
tree | 13ba86c25a5471f429f0dcf2d7e37ace474a0233 /sql/core | |
parent | 67d468f8d9172569ec9846edc6432240547696dd (diff) | |
download | spark-06e6b765d0c747b773d7f3be28ddb0543c955a1f.tar.gz spark-06e6b765d0c747b773d7f3be28ddb0543c955a1f.tar.bz2 spark-06e6b765d0c747b773d7f3be28ddb0543c955a1f.zip |
[SPARK-11149] [SQL] Improve cache performance for primitive types
This PR improve the performance by:
1) Generate an Iterator that take Iterator[CachedBatch] as input, and call accessors (unroll the loop for columns), avoid the expensive Iterator.flatMap.
2) Use Unsafe.getInt/getLong/getFloat/getDouble instead of ByteBuffer.getInt/getLong/getFloat/getDouble, the later one actually read byte by byte.
3) Remove the unnecessary copy() in Coalesce(), which is not related to memory cache, found during benchmark.
The following benchmark showed that we can speedup the columnar cache of int by 2x.
```
path = '/opt/tpcds/store_sales/'
int_cols = ['ss_sold_date_sk', 'ss_sold_time_sk', 'ss_item_sk','ss_customer_sk']
df = sqlContext.read.parquet(path).select(int_cols).cache()
df.count()
t = time.time()
print df.select("*")._jdf.queryExecution().toRdd().count()
print time.time() - t
```
Author: Davies Liu <davies@databricks.com>
Closes #9145 from davies/byte_buffer.
Diffstat (limited to 'sql/core')
6 files changed, 230 insertions, 78 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index a41f04dd3b..72fa299aa9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -28,6 +28,38 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types.UTF8String + +/** + * A help class for fast reading Int/Long/Float/Double from ByteBuffer in native order. + * + * WARNNING: This only works with HeapByteBuffer + */ +object ByteBufferHelper { + def getInt(buffer: ByteBuffer): Int = { + val pos = buffer.position() + buffer.position(pos + 4) + Platform.getInt(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos) + } + + def getLong(buffer: ByteBuffer): Long = { + val pos = buffer.position() + buffer.position(pos + 8) + Platform.getLong(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos) + } + + def getFloat(buffer: ByteBuffer): Float = { + val pos = buffer.position() + buffer.position(pos + 4) + Platform.getFloat(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos) + } + + def getDouble(buffer: ByteBuffer): Double = { + val pos = buffer.position() + buffer.position(pos + 8) + Platform.getDouble(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos) + } +} + /** * An abstract class that represents type of a column. Used to append/extract Java objects into/from * the underlying [[ByteBuffer]] of a column. @@ -134,11 +166,11 @@ private[sql] object INT extends NativeColumnType(IntegerType, 4) { } override def extract(buffer: ByteBuffer): Int = { - buffer.getInt() + ByteBufferHelper.getInt(buffer) } override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { - row.setInt(ordinal, buffer.getInt()) + row.setInt(ordinal, ByteBufferHelper.getInt(buffer)) } override def setField(row: MutableRow, ordinal: Int, value: Int): Unit = { @@ -163,11 +195,11 @@ private[sql] object LONG extends NativeColumnType(LongType, 8) { } override def extract(buffer: ByteBuffer): Long = { - buffer.getLong() + ByteBufferHelper.getLong(buffer) } override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { - row.setLong(ordinal, buffer.getLong()) + row.setLong(ordinal, ByteBufferHelper.getLong(buffer)) } override def setField(row: MutableRow, ordinal: Int, value: Long): Unit = { @@ -191,11 +223,11 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 4) { } override def extract(buffer: ByteBuffer): Float = { - buffer.getFloat() + ByteBufferHelper.getFloat(buffer) } override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { - row.setFloat(ordinal, buffer.getFloat()) + row.setFloat(ordinal, ByteBufferHelper.getFloat(buffer)) } override def setField(row: MutableRow, ordinal: Int, value: Float): Unit = { @@ -219,11 +251,11 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 8) { } override def extract(buffer: ByteBuffer): Double = { - buffer.getDouble() + ByteBufferHelper.getDouble(buffer) } override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { - row.setDouble(ordinal, buffer.getDouble()) + row.setDouble(ordinal, ByteBufferHelper.getDouble(buffer)) } override def setField(row: MutableRow, ordinal: Int, value: Double): Unit = { @@ -358,7 +390,7 @@ private[sql] case class COMPACT_DECIMAL(precision: Int, scale: Int) extends NativeColumnType(DecimalType(precision, scale), 8) { override def extract(buffer: ByteBuffer): Decimal = { - Decimal(buffer.getLong(), precision, scale) + Decimal(ByteBufferHelper.getLong(buffer), precision, scale) } override def append(v: Decimal, buffer: ByteBuffer): Unit = { @@ -480,7 +512,7 @@ private[sql] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRo } override def extract(buffer: ByteBuffer): UnsafeRow = { - val sizeInBytes = buffer.getInt() + val sizeInBytes = ByteBufferHelper.getInt(buffer) assert(buffer.hasArray) val cursor = buffer.position() buffer.position(cursor + sizeInBytes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala new file mode 100644 index 0000000000..e04bcda580 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodeGenerator} +import org.apache.spark.sql.types._ + +/** + * An Iterator to walk throught the InternalRows from a CachedBatch + */ +abstract class ColumnarIterator extends Iterator[InternalRow] { + def initialize(input: Iterator[CachedBatch], mutableRow: MutableRow, columnTypes: Array[DataType], + columnIndexes: Array[Int]): Unit +} + +/** + * Generates bytecode for an [[ColumnarIterator]] for columnar cache. + */ +object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarIterator] with Logging { + + protected def canonicalize(in: Seq[DataType]): Seq[DataType] = in + protected def bind(in: Seq[DataType], inputSchema: Seq[Attribute]): Seq[DataType] = in + + protected def create(columnTypes: Seq[DataType]): ColumnarIterator = { + val ctx = newCodeGenContext() + val (initializeAccessors, extractors) = columnTypes.zipWithIndex.map { case (dt, index) => + val accessorName = ctx.freshName("accessor") + val accessorCls = dt match { + case NullType => classOf[NullColumnAccessor].getName + case BooleanType => classOf[BooleanColumnAccessor].getName + case ByteType => classOf[ByteColumnAccessor].getName + case ShortType => classOf[ShortColumnAccessor].getName + case IntegerType | DateType => classOf[IntColumnAccessor].getName + case LongType | TimestampType => classOf[LongColumnAccessor].getName + case FloatType => classOf[FloatColumnAccessor].getName + case DoubleType => classOf[DoubleColumnAccessor].getName + case StringType => classOf[StringColumnAccessor].getName + case BinaryType => classOf[BinaryColumnAccessor].getName + case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => + classOf[CompactDecimalColumnAccessor].getName + case dt: DecimalType => classOf[DecimalColumnAccessor].getName + case struct: StructType => classOf[StructColumnAccessor].getName + case array: ArrayType => classOf[ArrayColumnAccessor].getName + case t: MapType => classOf[MapColumnAccessor].getName + } + ctx.addMutableState(accessorCls, accessorName, s"$accessorName = null;") + + val createCode = dt match { + case t if ctx.isPrimitiveType(dt) => + s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" + case NullType | StringType | BinaryType => + s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" + case other => + s"""$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder), + (${dt.getClass.getName}) columnTypes[$index]);""" + } + + val extract = s"$accessorName.extractTo(mutableRow, $index);" + + (createCode, extract) + }.unzip + + val code = s""" + import java.nio.ByteBuffer; + import java.nio.ByteOrder; + + public SpecificColumnarIterator generate($exprType[] expr) { + return new SpecificColumnarIterator(); + } + + class SpecificColumnarIterator extends ${classOf[ColumnarIterator].getName} { + + private ByteOrder nativeOrder = null; + private byte[][] buffers = null; + + private int currentRow = 0; + private int numRowsInBatch = 0; + + private scala.collection.Iterator input = null; + private MutableRow mutableRow = null; + private ${classOf[DataType].getName}[] columnTypes = null; + private int[] columnIndexes = null; + + ${declareMutableStates(ctx)} + + public SpecificColumnarIterator() { + this.nativeOrder = ByteOrder.nativeOrder(); + this.buffers = new byte[${columnTypes.length}][]; + + ${initMutableStates(ctx)} + } + + public void initialize(scala.collection.Iterator input, MutableRow mutableRow, + ${classOf[DataType].getName}[] columnTypes, int[] columnIndexes) { + this.input = input; + this.mutableRow = mutableRow; + this.columnTypes = columnTypes; + this.columnIndexes = columnIndexes; + } + + public boolean hasNext() { + if (currentRow < numRowsInBatch) { + return true; + } + if (!input.hasNext()) { + return false; + } + + ${classOf[CachedBatch].getName} batch = (${classOf[CachedBatch].getName}) input.next(); + currentRow = 0; + numRowsInBatch = batch.numRows(); + for (int i = 0; i < columnIndexes.length; i ++) { + buffers[i] = batch.buffers()[columnIndexes[i]]; + } + ${initializeAccessors.mkString("\n")} + + return hasNext(); + } + + public InternalRow next() { + ${extractors.mkString("\n")} + currentRow += 1; + return mutableRow; + } + }""" + + logDebug(s"Generated ColumnarIterator: ${CodeFormatter.format(code)}") + + compile(code).generate(ctx.references.toArray).asInstanceOf[ColumnarIterator] + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index d967814f62..9f76a61a15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.columnar -import java.nio.ByteBuffer - import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD @@ -28,6 +26,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.execution.{ConvertToUnsafe, LeafNode, SparkPlan} +import org.apache.spark.sql.types.UserDefinedType import org.apache.spark.storage.StorageLevel import org.apache.spark.{Accumulable, Accumulator, Accumulators} @@ -43,7 +42,14 @@ private[sql] object InMemoryRelation { tableName)() } -private[sql] case class CachedBatch(buffers: Array[Array[Byte]], stats: InternalRow) +/** + * CachedBatch is a cached batch of rows. + * + * @param numRows The total number of rows in this batch + * @param buffers The buffers for serialized columns + * @param stats The stat of columns + */ +private[sql] case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) private[sql] case class InMemoryRelation( output: Seq[Attribute], @@ -151,7 +157,7 @@ private[sql] case class InMemoryRelation( .flatMap(_.values)) batchStats += stats - CachedBatch(columnBuilders.map(_.build().array()), stats) + CachedBatch(rowCount, columnBuilders.map(_.build().array()), stats) } def hasNext: Boolean = rowIterator.hasNext @@ -278,59 +284,15 @@ private[sql] case class InMemoryColumnarTableScan( val buffers = relation.cachedColumnBuffers buffers.mapPartitions { cachedBatchIterator => - val partitionFilter = newPredicate( - partitionFilters.reduceOption(And).getOrElse(Literal(true)), - schema) - - // Find the ordinals and data types of the requested columns. If none are requested, use the - // narrowest (the field with minimum default element size). - val (requestedColumnIndices, requestedColumnDataTypes) = if (attributes.isEmpty) { - val (narrowestOrdinal, narrowestDataType) = - relOutput.zipWithIndex.map { case (a, ordinal) => - ordinal -> a.dataType - } minBy { case (_, dataType) => - ColumnType(dataType).defaultSize - } - Seq(narrowestOrdinal) -> Seq(narrowestDataType) - } else { + val partitionFilter = newPredicate( + partitionFilters.reduceOption(And).getOrElse(Literal(true)), + schema) + + // Find the ordinals and data types of the requested columns. + val (requestedColumnIndices, requestedColumnDataTypes) = attributes.map { a => relOutput.indexWhere(_.exprId == a.exprId) -> a.dataType }.unzip - } - - val nextRow = new SpecificMutableRow(requestedColumnDataTypes) - - def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]): Iterator[InternalRow] = { - val rows = cacheBatches.flatMap { cachedBatch => - // Build column accessors - val columnAccessors = requestedColumnIndices.map { batchColumnIndex => - ColumnAccessor( - relOutput(batchColumnIndex).dataType, - ByteBuffer.wrap(cachedBatch.buffers(batchColumnIndex))) - } - - // Extract rows via column accessors - new Iterator[InternalRow] { - private[this] val rowLen = nextRow.numFields - override def next(): InternalRow = { - var i = 0 - while (i < rowLen) { - columnAccessors(i).extractTo(nextRow, i) - i += 1 - } - if (attributes.isEmpty) InternalRow.empty else nextRow - } - - override def hasNext: Boolean = columnAccessors(0).hasNext - } - } - - if (rows.hasNext && enableAccumulators) { - readPartitions += 1 - } - - rows - } // Do partition batch pruning if enabled val cachedBatchesToScan = @@ -355,7 +317,18 @@ private[sql] case class InMemoryColumnarTableScan( cachedBatchIterator } - cachedBatchesToRows(cachedBatchesToScan) + val nextRow = new SpecificMutableRow(requestedColumnDataTypes) + val columnTypes = requestedColumnDataTypes.map { + case udt: UserDefinedType[_] => udt.sqlType + case other => other + }.toArray + val columnarIterator = GenerateColumnAccessor.generate(columnTypes) + columnarIterator.initialize(cachedBatchesToScan, nextRow, columnTypes, + requestedColumnIndices.toArray) + if (enableAccumulators && columnarIterator.hasNext) { + readPartitions += 1 + } + columnarIterator } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala index 4d35650d4b..7eaecfe047 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala @@ -31,8 +31,8 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor { abstract override protected def initialize(): Unit = { nullsBuffer = underlyingBuffer.duplicate().order(ByteOrder.nativeOrder()) - nullCount = nullsBuffer.getInt() - nextNullIndex = if (nullCount > 0) nullsBuffer.getInt() else -1 + nullCount = ByteBufferHelper.getInt(nullsBuffer) + nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1 pos = 0 underlyingBuffer.position(underlyingBuffer.position + 4 + nullCount * 4) @@ -44,7 +44,7 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor { seenNulls += 1 if (seenNulls < nullCount) { - nextNullIndex = nullsBuffer.getInt() + nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) } row.setNullAt(ordinal) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala index ca910a99db..41c9a284e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala @@ -20,13 +20,11 @@ package org.apache.spark.sql.columnar.compression import java.nio.ByteBuffer import scala.collection.mutable -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.runtimeMirror + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} import org.apache.spark.sql.columnar._ import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils private[sql] case object PassThrough extends CompressionScheme { @@ -161,7 +159,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { override def next(row: MutableRow, ordinal: Int): Unit = { if (valueCount == run) { currentValue = columnType.extract(buffer) - run = buffer.getInt() + run = ByteBufferHelper.getInt(buffer) valueCount = 1 } else { valueCount += 1 @@ -271,7 +269,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { extends compression.Decoder[T] { private val dictionary: Array[Any] = { - val elementNum = buffer.getInt() + val elementNum = ByteBufferHelper.getInt(buffer) Array.fill[Any](elementNum)(columnType.extract(buffer).asInstanceOf[Any]) } @@ -352,7 +350,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme { } class Decoder(buffer: ByteBuffer) extends compression.Decoder[BooleanType.type] { - private val count = buffer.getInt() + private val count = ByteBufferHelper.getInt(buffer) private var currentWord = 0: Long @@ -363,7 +361,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme { visited += 1 if (bit == 0) { - currentWord = buffer.getLong() + currentWord = ByteBufferHelper.getLong(buffer) } row.setBoolean(ordinal, ((currentWord >> bit) & 1) != 0) @@ -447,7 +445,7 @@ private[sql] case object IntDelta extends CompressionScheme { override def next(row: MutableRow, ordinal: Int): Unit = { val delta = buffer.get() - prev = if (delta > Byte.MinValue) prev + delta else buffer.getInt() + prev = if (delta > Byte.MinValue) prev + delta else ByteBufferHelper.getInt(buffer) row.setInt(ordinal, prev) } } @@ -527,7 +525,7 @@ private[sql] case object LongDelta extends CompressionScheme { override def next(row: MutableRow, ordinal: Int): Unit = { val delta = buffer.get() - prev = if (delta > Byte.MinValue) prev + delta else buffer.getLong() + prev = if (delta > Byte.MinValue) prev + delta else ByteBufferHelper.getLong(buffer) row.setLong(ordinal, prev) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 4db9f4ee67..dc38fe59fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -271,7 +271,7 @@ case class Coalesce(numPartitions: Int, child: SparkPlan) extends UnaryNode { } protected override def doExecute(): RDD[InternalRow] = { - child.execute().map(_.copy()).coalesce(numPartitions, shuffle = false) + child.execute().coalesce(numPartitions, shuffle = false) } override def canProcessUnsafeRows: Boolean = true |