Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala                      52
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala         149
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala       83
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala           6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala  16
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala                   2
6 files changed, 230 insertions, 78 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index a41f04dd3b..72fa299aa9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -28,6 +28,38 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * A helper class for fast reading of Int/Long/Float/Double values from a ByteBuffer in native byte order.
+ *
+ * WARNING: This only works with HeapByteBuffer.
+ */
+object ByteBufferHelper {
+ def getInt(buffer: ByteBuffer): Int = {
+ val pos = buffer.position()
+ buffer.position(pos + 4)
+ Platform.getInt(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos)
+ }
+
+ def getLong(buffer: ByteBuffer): Long = {
+ val pos = buffer.position()
+ buffer.position(pos + 8)
+ Platform.getLong(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos)
+ }
+
+ def getFloat(buffer: ByteBuffer): Float = {
+ val pos = buffer.position()
+ buffer.position(pos + 4)
+ Platform.getFloat(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos)
+ }
+
+ def getDouble(buffer: ByteBuffer): Double = {
+ val pos = buffer.position()
+ buffer.position(pos + 8)
+ Platform.getDouble(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos)
+ }
+}
+
/**
* An abstract class that represents type of a column. Used to append/extract Java objects into/from
* the underlying [[ByteBuffer]] of a column.
@@ -134,11 +166,11 @@ private[sql] object INT extends NativeColumnType(IntegerType, 4) {
}
override def extract(buffer: ByteBuffer): Int = {
- buffer.getInt()
+ ByteBufferHelper.getInt(buffer)
}
override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
- row.setInt(ordinal, buffer.getInt())
+ row.setInt(ordinal, ByteBufferHelper.getInt(buffer))
}
override def setField(row: MutableRow, ordinal: Int, value: Int): Unit = {
@@ -163,11 +195,11 @@ private[sql] object LONG extends NativeColumnType(LongType, 8) {
}
override def extract(buffer: ByteBuffer): Long = {
- buffer.getLong()
+ ByteBufferHelper.getLong(buffer)
}
override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
- row.setLong(ordinal, buffer.getLong())
+ row.setLong(ordinal, ByteBufferHelper.getLong(buffer))
}
override def setField(row: MutableRow, ordinal: Int, value: Long): Unit = {
@@ -191,11 +223,11 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 4) {
}
override def extract(buffer: ByteBuffer): Float = {
- buffer.getFloat()
+ ByteBufferHelper.getFloat(buffer)
}
override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
- row.setFloat(ordinal, buffer.getFloat())
+ row.setFloat(ordinal, ByteBufferHelper.getFloat(buffer))
}
override def setField(row: MutableRow, ordinal: Int, value: Float): Unit = {
@@ -219,11 +251,11 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 8) {
}
override def extract(buffer: ByteBuffer): Double = {
- buffer.getDouble()
+ ByteBufferHelper.getDouble(buffer)
}
override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
- row.setDouble(ordinal, buffer.getDouble())
+ row.setDouble(ordinal, ByteBufferHelper.getDouble(buffer))
}
override def setField(row: MutableRow, ordinal: Int, value: Double): Unit = {
@@ -358,7 +390,7 @@ private[sql] case class COMPACT_DECIMAL(precision: Int, scale: Int)
extends NativeColumnType(DecimalType(precision, scale), 8) {
override def extract(buffer: ByteBuffer): Decimal = {
- Decimal(buffer.getLong(), precision, scale)
+ Decimal(ByteBufferHelper.getLong(buffer), precision, scale)
}
override def append(v: Decimal, buffer: ByteBuffer): Unit = {
@@ -480,7 +512,7 @@ private[sql] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRo
}
override def extract(buffer: ByteBuffer): UnsafeRow = {
- val sizeInBytes = buffer.getInt()
+ val sizeInBytes = ByteBufferHelper.getInt(buffer)
assert(buffer.hasArray)
val cursor = buffer.position()
buffer.position(cursor + sizeInBytes)
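
For readers outside the Spark codebase, a minimal self-contained sketch of the trick the new ByteBufferHelper relies on: instead of calling ByteBuffer.getInt(), read the value straight from the backing array of a heap buffer and advance the position by hand, skipping the per-call bounds check and byte-order dispatch. The names below (ByteBufferHelperDemo, getIntFromArray) are illustrative and not part of the patch; the real helper delegates the raw read to Platform (Unsafe), which also handles native byte order directly.

import java.nio.{ByteBuffer, ByteOrder}

// Illustrative only: mirrors ByteBufferHelper.getInt without Spark's Platform/Unsafe.
// Assumes a heap buffer (buffer.array() must be available) and little-endian native order.
object ByteBufferHelperDemo {
  def getIntFromArray(buffer: ByteBuffer): Int = {
    val pos = buffer.position()
    buffer.position(pos + 4)            // advance past the 4 bytes we are about to read
    val arr = buffer.array()            // backing array; only valid for HeapByteBuffer
    // Assemble the int in little-endian order; Platform.getInt avoids even this by
    // reading the whole word via Unsafe in native order.
    (arr(pos) & 0xFF) |
      ((arr(pos + 1) & 0xFF) << 8) |
      ((arr(pos + 2) & 0xFF) << 16) |
      ((arr(pos + 3) & 0xFF) << 24)
  }

  def main(args: Array[String]): Unit = {
    val buf = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder())
    buf.putInt(42).putInt(7)
    buf.flip()
    println(getIntFromArray(buf))       // 42
    println(getIntFromArray(buf))       // 7
  }
}
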
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala
new file mode 100644
index 0000000000..e04bcda580
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodeGenerator}
+import org.apache.spark.sql.types._
+
+/**
+ * An Iterator to walk through the InternalRows from a CachedBatch.
+ */
+abstract class ColumnarIterator extends Iterator[InternalRow] {
+ def initialize(input: Iterator[CachedBatch], mutableRow: MutableRow, columnTypes: Array[DataType],
+ columnIndexes: Array[Int]): Unit
+}
+
+/**
+ * Generates bytecode for a [[ColumnarIterator]] for the columnar cache.
+ */
+object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarIterator] with Logging {
+
+ protected def canonicalize(in: Seq[DataType]): Seq[DataType] = in
+ protected def bind(in: Seq[DataType], inputSchema: Seq[Attribute]): Seq[DataType] = in
+
+ protected def create(columnTypes: Seq[DataType]): ColumnarIterator = {
+ val ctx = newCodeGenContext()
+ val (initializeAccessors, extractors) = columnTypes.zipWithIndex.map { case (dt, index) =>
+ val accessorName = ctx.freshName("accessor")
+ val accessorCls = dt match {
+ case NullType => classOf[NullColumnAccessor].getName
+ case BooleanType => classOf[BooleanColumnAccessor].getName
+ case ByteType => classOf[ByteColumnAccessor].getName
+ case ShortType => classOf[ShortColumnAccessor].getName
+ case IntegerType | DateType => classOf[IntColumnAccessor].getName
+ case LongType | TimestampType => classOf[LongColumnAccessor].getName
+ case FloatType => classOf[FloatColumnAccessor].getName
+ case DoubleType => classOf[DoubleColumnAccessor].getName
+ case StringType => classOf[StringColumnAccessor].getName
+ case BinaryType => classOf[BinaryColumnAccessor].getName
+ case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
+ classOf[CompactDecimalColumnAccessor].getName
+ case dt: DecimalType => classOf[DecimalColumnAccessor].getName
+ case struct: StructType => classOf[StructColumnAccessor].getName
+ case array: ArrayType => classOf[ArrayColumnAccessor].getName
+ case t: MapType => classOf[MapColumnAccessor].getName
+ }
+ ctx.addMutableState(accessorCls, accessorName, s"$accessorName = null;")
+
+ val createCode = dt match {
+ case t if ctx.isPrimitiveType(dt) =>
+ s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
+ case NullType | StringType | BinaryType =>
+ s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
+ case other =>
+ s"""$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder),
+ (${dt.getClass.getName}) columnTypes[$index]);"""
+ }
+
+ val extract = s"$accessorName.extractTo(mutableRow, $index);"
+
+ (createCode, extract)
+ }.unzip
+
+ val code = s"""
+ import java.nio.ByteBuffer;
+ import java.nio.ByteOrder;
+
+ public SpecificColumnarIterator generate($exprType[] expr) {
+ return new SpecificColumnarIterator();
+ }
+
+ class SpecificColumnarIterator extends ${classOf[ColumnarIterator].getName} {
+
+ private ByteOrder nativeOrder = null;
+ private byte[][] buffers = null;
+
+ private int currentRow = 0;
+ private int numRowsInBatch = 0;
+
+ private scala.collection.Iterator input = null;
+ private MutableRow mutableRow = null;
+ private ${classOf[DataType].getName}[] columnTypes = null;
+ private int[] columnIndexes = null;
+
+ ${declareMutableStates(ctx)}
+
+ public SpecificColumnarIterator() {
+ this.nativeOrder = ByteOrder.nativeOrder();
+ this.buffers = new byte[${columnTypes.length}][];
+
+ ${initMutableStates(ctx)}
+ }
+
+ public void initialize(scala.collection.Iterator input, MutableRow mutableRow,
+ ${classOf[DataType].getName}[] columnTypes, int[] columnIndexes) {
+ this.input = input;
+ this.mutableRow = mutableRow;
+ this.columnTypes = columnTypes;
+ this.columnIndexes = columnIndexes;
+ }
+
+ public boolean hasNext() {
+ if (currentRow < numRowsInBatch) {
+ return true;
+ }
+ if (!input.hasNext()) {
+ return false;
+ }
+
+ ${classOf[CachedBatch].getName} batch = (${classOf[CachedBatch].getName}) input.next();
+ currentRow = 0;
+ numRowsInBatch = batch.numRows();
+ for (int i = 0; i < columnIndexes.length; i ++) {
+ buffers[i] = batch.buffers()[columnIndexes[i]];
+ }
+ ${initializeAccessors.mkString("\n")}
+
+ return hasNext();
+ }
+
+ public InternalRow next() {
+ ${extractors.mkString("\n")}
+ currentRow += 1;
+ return mutableRow;
+ }
+ }"""
+
+ logDebug(s"Generated ColumnarIterator: ${CodeFormatter.format(code)}")
+
+ compile(code).generate(ctx.references.toArray).asInstanceOf[ColumnarIterator]
+ }
+}
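
To make the generated template above concrete, here is a hand-written Scala equivalent of roughly what create() would emit for a fixed two-column (Int, String) schema. This is a sketch only: the real output is Java source compiled at runtime, its field names come from freshName, and the class name below is illustrative.

package org.apache.spark.sql.columnar

import java.nio.{ByteBuffer, ByteOrder}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.MutableRow
import org.apache.spark.sql.types.DataType

// Illustrative equivalent of a generated SpecificColumnarIterator for (Int, String).
class HandWrittenColumnarIterator extends ColumnarIterator {
  private val nativeOrder = ByteOrder.nativeOrder()
  private val buffers = new Array[Array[Byte]](2)
  private var currentRow = 0
  private var numRowsInBatch = 0
  private var input: Iterator[CachedBatch] = _
  private var mutableRow: MutableRow = _
  private var columnIndexes: Array[Int] = _
  private var intAccessor: IntColumnAccessor = _
  private var stringAccessor: StringColumnAccessor = _

  override def initialize(
      input: Iterator[CachedBatch],
      mutableRow: MutableRow,
      columnTypes: Array[DataType],
      columnIndexes: Array[Int]): Unit = {
    this.input = input
    this.mutableRow = mutableRow
    this.columnIndexes = columnIndexes
  }

  override def hasNext: Boolean = {
    if (currentRow < numRowsInBatch) return true
    if (!input.hasNext) return false
    // Move to the next cached batch and rebuild one accessor per requested column.
    val batch = input.next()
    currentRow = 0
    numRowsInBatch = batch.numRows
    buffers(0) = batch.buffers(columnIndexes(0))
    buffers(1) = batch.buffers(columnIndexes(1))
    intAccessor = new IntColumnAccessor(ByteBuffer.wrap(buffers(0)).order(nativeOrder))
    stringAccessor = new StringColumnAccessor(ByteBuffer.wrap(buffers(1)).order(nativeOrder))
    hasNext
  }

  override def next(): InternalRow = {
    intAccessor.extractTo(mutableRow, 0)
    stringAccessor.extractTo(mutableRow, 1)
    currentRow += 1
    mutableRow
  }
}
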
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index d967814f62..9f76a61a15 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.columnar
-import java.nio.ByteBuffer
-
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.rdd.RDD
@@ -28,6 +26,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
import org.apache.spark.sql.execution.{ConvertToUnsafe, LeafNode, SparkPlan}
+import org.apache.spark.sql.types.UserDefinedType
import org.apache.spark.storage.StorageLevel
import org.apache.spark.{Accumulable, Accumulator, Accumulators}
@@ -43,7 +42,14 @@ private[sql] object InMemoryRelation {
tableName)()
}
-private[sql] case class CachedBatch(buffers: Array[Array[Byte]], stats: InternalRow)
+/**
+ * CachedBatch is a cached batch of rows.
+ *
+ * @param numRows The total number of rows in this batch
+ * @param buffers The buffers for serialized columns
+ * @param stats The statistics of the columns in this batch
+ */
+private[sql] case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow)
private[sql] case class InMemoryRelation(
output: Seq[Attribute],
@@ -151,7 +157,7 @@ private[sql] case class InMemoryRelation(
.flatMap(_.values))
batchStats += stats
- CachedBatch(columnBuilders.map(_.build().array()), stats)
+ CachedBatch(rowCount, columnBuilders.map(_.build().array()), stats)
}
def hasNext: Boolean = rowIterator.hasNext
@@ -278,59 +284,15 @@ private[sql] case class InMemoryColumnarTableScan(
val buffers = relation.cachedColumnBuffers
buffers.mapPartitions { cachedBatchIterator =>
- val partitionFilter = newPredicate(
- partitionFilters.reduceOption(And).getOrElse(Literal(true)),
- schema)
-
- // Find the ordinals and data types of the requested columns. If none are requested, use the
- // narrowest (the field with minimum default element size).
- val (requestedColumnIndices, requestedColumnDataTypes) = if (attributes.isEmpty) {
- val (narrowestOrdinal, narrowestDataType) =
- relOutput.zipWithIndex.map { case (a, ordinal) =>
- ordinal -> a.dataType
- } minBy { case (_, dataType) =>
- ColumnType(dataType).defaultSize
- }
- Seq(narrowestOrdinal) -> Seq(narrowestDataType)
- } else {
+ val partitionFilter = newPredicate(
+ partitionFilters.reduceOption(And).getOrElse(Literal(true)),
+ schema)
+
+ // Find the ordinals and data types of the requested columns.
+ val (requestedColumnIndices, requestedColumnDataTypes) =
attributes.map { a =>
relOutput.indexWhere(_.exprId == a.exprId) -> a.dataType
}.unzip
- }
-
- val nextRow = new SpecificMutableRow(requestedColumnDataTypes)
-
- def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]): Iterator[InternalRow] = {
- val rows = cacheBatches.flatMap { cachedBatch =>
- // Build column accessors
- val columnAccessors = requestedColumnIndices.map { batchColumnIndex =>
- ColumnAccessor(
- relOutput(batchColumnIndex).dataType,
- ByteBuffer.wrap(cachedBatch.buffers(batchColumnIndex)))
- }
-
- // Extract rows via column accessors
- new Iterator[InternalRow] {
- private[this] val rowLen = nextRow.numFields
- override def next(): InternalRow = {
- var i = 0
- while (i < rowLen) {
- columnAccessors(i).extractTo(nextRow, i)
- i += 1
- }
- if (attributes.isEmpty) InternalRow.empty else nextRow
- }
-
- override def hasNext: Boolean = columnAccessors(0).hasNext
- }
- }
-
- if (rows.hasNext && enableAccumulators) {
- readPartitions += 1
- }
-
- rows
- }
// Do partition batch pruning if enabled
val cachedBatchesToScan =
@@ -355,7 +317,18 @@ private[sql] case class InMemoryColumnarTableScan(
cachedBatchIterator
}
- cachedBatchesToRows(cachedBatchesToScan)
+ val nextRow = new SpecificMutableRow(requestedColumnDataTypes)
+ val columnTypes = requestedColumnDataTypes.map {
+ case udt: UserDefinedType[_] => udt.sqlType
+ case other => other
+ }.toArray
+ val columnarIterator = GenerateColumnAccessor.generate(columnTypes)
+ columnarIterator.initialize(cachedBatchesToScan, nextRow, columnTypes,
+ requestedColumnIndices.toArray)
+ if (enableAccumulators && columnarIterator.hasNext) {
+ readPartitions += 1
+ }
+ columnarIterator
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala
index 4d35650d4b..7eaecfe047 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala
@@ -31,8 +31,8 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor {
abstract override protected def initialize(): Unit = {
nullsBuffer = underlyingBuffer.duplicate().order(ByteOrder.nativeOrder())
- nullCount = nullsBuffer.getInt()
- nextNullIndex = if (nullCount > 0) nullsBuffer.getInt() else -1
+ nullCount = ByteBufferHelper.getInt(nullsBuffer)
+ nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1
pos = 0
underlyingBuffer.position(underlyingBuffer.position + 4 + nullCount * 4)
@@ -44,7 +44,7 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor {
seenNulls += 1
if (seenNulls < nullCount) {
- nextNullIndex = nullsBuffer.getInt()
+ nextNullIndex = ByteBufferHelper.getInt(nullsBuffer)
}
row.setNullAt(ordinal)
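
For context, a small sketch of the null header this accessor is decoding. The layout (a null count, then the ordinals of the null rows, then the non-null values) is inferred from the reads above; the names are illustrative and not Spark API.

import java.nio.{ByteBuffer, ByteOrder}

object NullHeaderDemo {
  def main(args: Array[String]): Unit = {
    // Build a buffer for an Int column with rows (10, null, 30): one null, at ordinal 1.
    val buf = ByteBuffer.allocate(4 + 1 * 4 + 2 * 4).order(ByteOrder.nativeOrder())
    buf.putInt(1).putInt(1)             // null count, then the null ordinal
    buf.putInt(10).putInt(30)           // non-null values only
    buf.flip()

    // Decode the header the way NullableColumnAccessor does: read it through a
    // duplicate, then skip it on the main buffer before reading values.
    val nulls = buf.duplicate().order(ByteOrder.nativeOrder())
    val nullCount = nulls.getInt()
    val nextNullIndex = if (nullCount > 0) nulls.getInt() else -1
    buf.position(buf.position() + 4 + nullCount * 4)

    println(s"nullCount=$nullCount nextNull=$nextNullIndex firstValue=${buf.getInt()}")
  }
}
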
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
index ca910a99db..41c9a284e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
@@ -20,13 +20,11 @@ package org.apache.spark.sql.columnar.compression
import java.nio.ByteBuffer
import scala.collection.mutable
-import scala.reflect.ClassTag
-import scala.reflect.runtime.universe.runtimeMirror
+
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow}
import org.apache.spark.sql.columnar._
import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
private[sql] case object PassThrough extends CompressionScheme {
@@ -161,7 +159,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme {
override def next(row: MutableRow, ordinal: Int): Unit = {
if (valueCount == run) {
currentValue = columnType.extract(buffer)
- run = buffer.getInt()
+ run = ByteBufferHelper.getInt(buffer)
valueCount = 1
} else {
valueCount += 1
@@ -271,7 +269,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme {
extends compression.Decoder[T] {
private val dictionary: Array[Any] = {
- val elementNum = buffer.getInt()
+ val elementNum = ByteBufferHelper.getInt(buffer)
Array.fill[Any](elementNum)(columnType.extract(buffer).asInstanceOf[Any])
}
@@ -352,7 +350,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme {
}
class Decoder(buffer: ByteBuffer) extends compression.Decoder[BooleanType.type] {
- private val count = buffer.getInt()
+ private val count = ByteBufferHelper.getInt(buffer)
private var currentWord = 0: Long
@@ -363,7 +361,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme {
visited += 1
if (bit == 0) {
- currentWord = buffer.getLong()
+ currentWord = ByteBufferHelper.getLong(buffer)
}
row.setBoolean(ordinal, ((currentWord >> bit) & 1) != 0)
@@ -447,7 +445,7 @@ private[sql] case object IntDelta extends CompressionScheme {
override def next(row: MutableRow, ordinal: Int): Unit = {
val delta = buffer.get()
- prev = if (delta > Byte.MinValue) prev + delta else buffer.getInt()
+ prev = if (delta > Byte.MinValue) prev + delta else ByteBufferHelper.getInt(buffer)
row.setInt(ordinal, prev)
}
}
@@ -527,7 +525,7 @@ private[sql] case object LongDelta extends CompressionScheme {
override def next(row: MutableRow, ordinal: Int): Unit = {
val delta = buffer.get()
- prev = if (delta > Byte.MinValue) prev + delta else buffer.getLong()
+ prev = if (delta > Byte.MinValue) prev + delta else ByteBufferHelper.getLong(buffer)
row.setLong(ordinal, prev)
}
}
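
As a worked illustration of the IntDelta decode path touched above, a standalone sketch of the same logic: each value is stored as a signed one-byte delta from the previous value, with Byte.MinValue acting as an escape marker followed by the full four-byte value. Names are illustrative, not Spark API.

import java.nio.{ByteBuffer, ByteOrder}

object IntDeltaDemo {
  def decode(buffer: ByteBuffer, count: Int): Array[Int] = {
    val out = new Array[Int](count)
    var prev = 0
    var i = 0
    while (i < count) {
      val delta = buffer.get()
      // Small deltas fit in one byte; Byte.MinValue means "read the full int".
      prev = if (delta > Byte.MinValue) prev + delta else buffer.getInt()
      out(i) = prev
      i += 1
    }
    out
  }

  def main(args: Array[String]): Unit = {
    val buf = ByteBuffer.allocate(16).order(ByteOrder.nativeOrder())
    buf.put(Byte.MinValue).putInt(5)    // escape marker + full 4-byte value
    buf.put(2.toByte)                   // 5 + 2 = 7
    buf.put(Byte.MinValue).putInt(1000) // jump too large for a byte, escaped again
    buf.flip()
    println(decode(buf, 3).mkString(",")) // 5,7,1000
  }
}
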
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 4db9f4ee67..dc38fe59fe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -271,7 +271,7 @@ case class Coalesce(numPartitions: Int, child: SparkPlan) extends UnaryNode {
}
protected override def doExecute(): RDD[InternalRow] = {
- child.execute().map(_.copy()).coalesce(numPartitions, shuffle = false)
+ child.execute().coalesce(numPartitions, shuffle = false)
}
override def canProcessUnsafeRows: Boolean = true