author     Davies Liu <davies@databricks.com>    2015-10-20 14:01:53 -0700
committer  Davies Liu <davies.liu@gmail.com>     2015-10-20 14:01:53 -0700
commit     06e6b765d0c747b773d7f3be28ddb0543c955a1f (patch)
tree       13ba86c25a5471f429f0dcf2d7e37ace474a0233 /sql
parent     67d468f8d9172569ec9846edc6432240547696dd (diff)
[SPARK-11149] [SQL] Improve cache performance for primitive types
This PR improves performance by:

1) Generating an Iterator that takes Iterator[CachedBatch] as input and calls the accessors directly (unrolling the loop over columns), avoiding the expensive Iterator.flatMap.
2) Using Unsafe.getInt/getLong/getFloat/getDouble instead of ByteBuffer.getInt/getLong/getFloat/getDouble; the latter actually read byte by byte.
3) Removing an unnecessary copy() in Coalesce(), which is not related to the memory cache but was found during benchmarking.

The following benchmark shows that we can speed up the columnar cache of ints by 2x.

```
path = '/opt/tpcds/store_sales/'
int_cols = ['ss_sold_date_sk', 'ss_sold_time_sk', 'ss_item_sk', 'ss_customer_sk']
df = sqlContext.read.parquet(path).select(int_cols).cache()
df.count()
t = time.time()
print df.select("*")._jdf.queryExecution().toRdd().count()
print time.time() - t
```

Author: Davies Liu <davies@databricks.com>

Closes #9145 from davies/byte_buffer.
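To illustrate point 2 outside the cache path, here is a minimal micro-benchmark sketch (not part of this patch) comparing HeapByteBuffer.getInt, which assembles each value from individual bytes, against word-at-a-time reads through Spark's Platform (Unsafe) helpers. It assumes org.apache.spark.unsafe.Platform is on the classpath; the array size is arbitrary and timings are only indicative.

```
import java.nio.{ByteBuffer, ByteOrder}

import org.apache.spark.unsafe.Platform

object ReadIntBench {
  def main(args: Array[String]): Unit = {
    val bytes = new Array[Byte](4 << 20)   // 1M ints, all zero

    // HeapByteBuffer.getInt assembles each value byte by byte.
    def viaByteBuffer(): Long = {
      val buf = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder())
      var sum = 0L
      while (buf.remaining() >= 4) sum += buf.getInt()
      sum
    }

    // Platform.getInt reads a whole 4-byte word from the backing array.
    def viaPlatform(): Long = {
      var pos = 0
      var sum = 0L
      while (pos < bytes.length) {
        sum += Platform.getInt(bytes, Platform.BYTE_ARRAY_OFFSET + pos)
        pos += 4
      }
      sum
    }

    Seq("ByteBuffer" -> (() => viaByteBuffer()), "Platform" -> (() => viaPlatform())).foreach {
      case (name, run) =>
        val start = System.nanoTime()
        val sum = run()
        println(f"$name%-10s ${(System.nanoTime() - start) / 1e6}%.1f ms (sum=$sum)")
    }
  }
}
```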
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 10
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala | 66
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala | 52
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala | 149
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala | 83
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala | 16
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala | 2
9 files changed, 265 insertions, 122 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala
index c98182c96b..9b8b6382d7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala
@@ -32,6 +32,7 @@ private class CodeFormatter {
private var indentLevel = 0
private val indentSize = 2
private var indentString = ""
+ private var currentLine = 1
private def addLine(line: String): Unit = {
val indentChange =
@@ -44,11 +45,13 @@ private class CodeFormatter {
} else {
indentString
}
+ code.append(f"/* ${currentLine}%03d */ ")
code.append(thisLineIndent)
code.append(line)
code.append("\n")
indentLevel = newIndentLevel
indentString = " " * (indentSize * newIndentLevel)
+ currentLine += 1
}
private def addLines(code: String): CodeFormatter = {
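For illustration, a hedged sketch (not part of the patch) of what the formatter now emits; the expected output mirrors CodeFormatterSuite below. It assumes the snippet lives in the codegen package, since CodeFormatter is package-private.

```
package org.apache.spark.sql.catalyst.expressions.codegen

object FormatterDemo {
  def main(args: Array[String]): Unit = {
    // Each formatted line is now prefixed with a /* NNN */ line number, so
    // Janino compile errors can be matched against the dumped source directly.
    println(CodeFormatter.format("class A {\nblahblah;\n}"))
    // /* 001 */ class A {
    // /* 002 */   blahblah;
    // /* 003 */ }
  }
}
```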
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 7544d27e3d..a4ec5085fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -391,26 +391,24 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
classOf[ArrayData].getName,
classOf[UnsafeArrayData].getName,
classOf[MapData].getName,
- classOf[UnsafeMapData].getName
+ classOf[UnsafeMapData].getName,
+ classOf[MutableRow].getName
))
evaluator.setExtendedClass(classOf[GeneratedClass])
def formatted = CodeFormatter.format(code)
- def withLineNums = formatted.split("\n").zipWithIndex.map {
- case (l, n) => f"${n + 1}%03d $l"
- }.mkString("\n")
logDebug({
// Only add extra debugging info to byte code when we are going to print the source code.
evaluator.setDebuggingInformation(true, true, false)
- withLineNums
+ formatted
})
try {
evaluator.cook("generated.java", code)
} catch {
case e: Exception =>
- val msg = s"failed to compile: $e\n$withLineNums"
+ val msg = s"failed to compile: $e\n$formatted"
logError(msg, e)
throw new Exception(msg, e)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala
index 46daa3eb8b..9da1068e9c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala
@@ -29,78 +29,68 @@ class CodeFormatterSuite extends SparkFunSuite {
}
testCase("basic example") {
- """
- |class A {
+ """class A {
|blahblah;
- |}
- """.stripMargin
+ |}""".stripMargin
}{
"""
- |class A {
- | blahblah;
- |}
+ |/* 001 */ class A {
+ |/* 002 */ blahblah;
+ |/* 003 */ }
""".stripMargin
}
testCase("nested example") {
- """
- |class A {
+ """class A {
| if (c) {
|duh;
|}
- |}
- """.stripMargin
+ |}""".stripMargin
} {
"""
- |class A {
- | if (c) {
- | duh;
- | }
- |}
+ |/* 001 */ class A {
+ |/* 002 */ if (c) {
+ |/* 003 */ duh;
+ |/* 004 */ }
+ |/* 005 */ }
""".stripMargin
}
testCase("single line") {
- """
- |class A {
+ """class A {
| if (c) {duh;}
- |}
- """.stripMargin
+ |}""".stripMargin
}{
"""
- |class A {
- | if (c) {duh;}
- |}
+ |/* 001 */ class A {
+ |/* 002 */ if (c) {duh;}
+ |/* 003 */ }
""".stripMargin
}
testCase("if else on the same line") {
- """
- |class A {
+ """class A {
| if (c) {duh;} else {boo;}
- |}
- """.stripMargin
+ |}""".stripMargin
}{
"""
- |class A {
- | if (c) {duh;} else {boo;}
- |}
+ |/* 001 */ class A {
+ |/* 002 */ if (c) {duh;} else {boo;}
+ |/* 003 */ }
""".stripMargin
}
testCase("function calls") {
- """
- |foo(
+ """foo(
|a,
|b,
- |c)
- """.stripMargin
+ |c)""".stripMargin
}{
"""
- |foo(
- | a,
- | b,
- | c)
+ |/* 001 */ foo(
+ |/* 002 */ a,
+ |/* 003 */ b,
+ |/* 004 */ c)
""".stripMargin
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index a41f04dd3b..72fa299aa9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -28,6 +28,38 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * A helper class for fast reading of Int/Long/Float/Double from a ByteBuffer in native order.
+ *
+ * WARNING: This only works with HeapByteBuffer
+ */
+object ByteBufferHelper {
+ def getInt(buffer: ByteBuffer): Int = {
+ val pos = buffer.position()
+ buffer.position(pos + 4)
+ Platform.getInt(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos)
+ }
+
+ def getLong(buffer: ByteBuffer): Long = {
+ val pos = buffer.position()
+ buffer.position(pos + 8)
+ Platform.getLong(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos)
+ }
+
+ def getFloat(buffer: ByteBuffer): Float = {
+ val pos = buffer.position()
+ buffer.position(pos + 4)
+ Platform.getFloat(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos)
+ }
+
+ def getDouble(buffer: ByteBuffer): Double = {
+ val pos = buffer.position()
+ buffer.position(pos + 8)
+ Platform.getDouble(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos)
+ }
+}
+
/**
* An abstract class that represents type of a column. Used to append/extract Java objects into/from
* the underlying [[ByteBuffer]] of a column.
@@ -134,11 +166,11 @@ private[sql] object INT extends NativeColumnType(IntegerType, 4) {
}
override def extract(buffer: ByteBuffer): Int = {
- buffer.getInt()
+ ByteBufferHelper.getInt(buffer)
}
override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
- row.setInt(ordinal, buffer.getInt())
+ row.setInt(ordinal, ByteBufferHelper.getInt(buffer))
}
override def setField(row: MutableRow, ordinal: Int, value: Int): Unit = {
@@ -163,11 +195,11 @@ private[sql] object LONG extends NativeColumnType(LongType, 8) {
}
override def extract(buffer: ByteBuffer): Long = {
- buffer.getLong()
+ ByteBufferHelper.getLong(buffer)
}
override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
- row.setLong(ordinal, buffer.getLong())
+ row.setLong(ordinal, ByteBufferHelper.getLong(buffer))
}
override def setField(row: MutableRow, ordinal: Int, value: Long): Unit = {
@@ -191,11 +223,11 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 4) {
}
override def extract(buffer: ByteBuffer): Float = {
- buffer.getFloat()
+ ByteBufferHelper.getFloat(buffer)
}
override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
- row.setFloat(ordinal, buffer.getFloat())
+ row.setFloat(ordinal, ByteBufferHelper.getFloat(buffer))
}
override def setField(row: MutableRow, ordinal: Int, value: Float): Unit = {
@@ -219,11 +251,11 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 8) {
}
override def extract(buffer: ByteBuffer): Double = {
- buffer.getDouble()
+ ByteBufferHelper.getDouble(buffer)
}
override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
- row.setDouble(ordinal, buffer.getDouble())
+ row.setDouble(ordinal, ByteBufferHelper.getDouble(buffer))
}
override def setField(row: MutableRow, ordinal: Int, value: Double): Unit = {
@@ -358,7 +390,7 @@ private[sql] case class COMPACT_DECIMAL(precision: Int, scale: Int)
extends NativeColumnType(DecimalType(precision, scale), 8) {
override def extract(buffer: ByteBuffer): Decimal = {
- Decimal(buffer.getLong(), precision, scale)
+ Decimal(ByteBufferHelper.getLong(buffer), precision, scale)
}
override def append(v: Decimal, buffer: ByteBuffer): Unit = {
@@ -480,7 +512,7 @@ private[sql] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRo
}
override def extract(buffer: ByteBuffer): UnsafeRow = {
- val sizeInBytes = buffer.getInt()
+ val sizeInBytes = ByteBufferHelper.getInt(buffer)
assert(buffer.hasArray)
val cursor = buffer.position()
buffer.position(cursor + sizeInBytes)
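A minimal usage sketch of the new helper (not from the patch), e.g. in a Scala REPL with the sql/core classes on the classpath and the object accessible; as the warning above states, it only works for heap, array-backed buffers.

```
import java.nio.{ByteBuffer, ByteOrder}
import org.apache.spark.sql.columnar.ByteBufferHelper

// Heap buffer written in native order, as the column builders do.
val buf = ByteBuffer.allocate(12).order(ByteOrder.nativeOrder())
buf.putInt(42)
buf.putLong(7L)
buf.flip()

// Each call reads one word via Platform and advances the position itself.
assert(ByteBufferHelper.getInt(buf) == 42)
assert(ByteBufferHelper.getLong(buf) == 7L)
assert(buf.remaining() == 0)
```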
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala
new file mode 100644
index 0000000000..e04bcda580
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodeGenerator}
+import org.apache.spark.sql.types._
+
+/**
+ * An Iterator to walk through the InternalRows from a CachedBatch
+ */
+abstract class ColumnarIterator extends Iterator[InternalRow] {
+ def initialize(input: Iterator[CachedBatch], mutableRow: MutableRow, columnTypes: Array[DataType],
+ columnIndexes: Array[Int]): Unit
+}
+
+/**
+ * Generates bytecode for a [[ColumnarIterator]] for the columnar cache.
+ */
+object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarIterator] with Logging {
+
+ protected def canonicalize(in: Seq[DataType]): Seq[DataType] = in
+ protected def bind(in: Seq[DataType], inputSchema: Seq[Attribute]): Seq[DataType] = in
+
+ protected def create(columnTypes: Seq[DataType]): ColumnarIterator = {
+ val ctx = newCodeGenContext()
+ val (initializeAccessors, extractors) = columnTypes.zipWithIndex.map { case (dt, index) =>
+ val accessorName = ctx.freshName("accessor")
+ val accessorCls = dt match {
+ case NullType => classOf[NullColumnAccessor].getName
+ case BooleanType => classOf[BooleanColumnAccessor].getName
+ case ByteType => classOf[ByteColumnAccessor].getName
+ case ShortType => classOf[ShortColumnAccessor].getName
+ case IntegerType | DateType => classOf[IntColumnAccessor].getName
+ case LongType | TimestampType => classOf[LongColumnAccessor].getName
+ case FloatType => classOf[FloatColumnAccessor].getName
+ case DoubleType => classOf[DoubleColumnAccessor].getName
+ case StringType => classOf[StringColumnAccessor].getName
+ case BinaryType => classOf[BinaryColumnAccessor].getName
+ case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
+ classOf[CompactDecimalColumnAccessor].getName
+ case dt: DecimalType => classOf[DecimalColumnAccessor].getName
+ case struct: StructType => classOf[StructColumnAccessor].getName
+ case array: ArrayType => classOf[ArrayColumnAccessor].getName
+ case t: MapType => classOf[MapColumnAccessor].getName
+ }
+ ctx.addMutableState(accessorCls, accessorName, s"$accessorName = null;")
+
+ val createCode = dt match {
+ case t if ctx.isPrimitiveType(dt) =>
+ s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
+ case NullType | StringType | BinaryType =>
+ s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
+ case other =>
+ s"""$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder),
+ (${dt.getClass.getName}) columnTypes[$index]);"""
+ }
+
+ val extract = s"$accessorName.extractTo(mutableRow, $index);"
+
+ (createCode, extract)
+ }.unzip
+
+ val code = s"""
+ import java.nio.ByteBuffer;
+ import java.nio.ByteOrder;
+
+ public SpecificColumnarIterator generate($exprType[] expr) {
+ return new SpecificColumnarIterator();
+ }
+
+ class SpecificColumnarIterator extends ${classOf[ColumnarIterator].getName} {
+
+ private ByteOrder nativeOrder = null;
+ private byte[][] buffers = null;
+
+ private int currentRow = 0;
+ private int numRowsInBatch = 0;
+
+ private scala.collection.Iterator input = null;
+ private MutableRow mutableRow = null;
+ private ${classOf[DataType].getName}[] columnTypes = null;
+ private int[] columnIndexes = null;
+
+ ${declareMutableStates(ctx)}
+
+ public SpecificColumnarIterator() {
+ this.nativeOrder = ByteOrder.nativeOrder();
+ this.buffers = new byte[${columnTypes.length}][];
+
+ ${initMutableStates(ctx)}
+ }
+
+ public void initialize(scala.collection.Iterator input, MutableRow mutableRow,
+ ${classOf[DataType].getName}[] columnTypes, int[] columnIndexes) {
+ this.input = input;
+ this.mutableRow = mutableRow;
+ this.columnTypes = columnTypes;
+ this.columnIndexes = columnIndexes;
+ }
+
+ public boolean hasNext() {
+ if (currentRow < numRowsInBatch) {
+ return true;
+ }
+ if (!input.hasNext()) {
+ return false;
+ }
+
+ ${classOf[CachedBatch].getName} batch = (${classOf[CachedBatch].getName}) input.next();
+ currentRow = 0;
+ numRowsInBatch = batch.numRows();
+ for (int i = 0; i < columnIndexes.length; i ++) {
+ buffers[i] = batch.buffers()[columnIndexes[i]];
+ }
+ ${initializeAccessors.mkString("\n")}
+
+ return hasNext();
+ }
+
+ public InternalRow next() {
+ ${extractors.mkString("\n")}
+ currentRow += 1;
+ return mutableRow;
+ }
+ }"""
+
+ logDebug(s"Generated ColumnarIterator: ${CodeFormatter.format(code)}")
+
+ compile(code).generate(ctx.references.toArray).asInstanceOf[ColumnarIterator]
+ }
+}
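A hedged sketch of how a caller wires up the generated iterator; it mirrors the wiring in InMemoryColumnarTableScan further down in this diff. The two column types are illustrative, `cachedBatches` stands in for a partition's Iterator[CachedBatch], and the code is assumed to live in the columnar package since several of these classes are private[sql].

```
package org.apache.spark.sql.columnar

import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
import org.apache.spark.sql.types.{DataType, IntegerType, LongType}

object ColumnarIteratorSketch {
  def scan(cachedBatches: Iterator[CachedBatch]): Unit = {
    val columnTypes: Array[DataType] = Array(IntegerType, LongType)
    val mutableRow = new SpecificMutableRow(columnTypes)

    // Compile an iterator specialized for these column types ...
    val iter = GenerateColumnAccessor.generate(columnTypes)
    // ... and point it at the cached batches and the column ordinals to read.
    iter.initialize(cachedBatches, mutableRow, columnTypes, Array(0, 1))

    while (iter.hasNext) {
      val row = iter.next()   // the same MutableRow instance, updated in place
      // consume `row` before calling next() again; it is reused across rows
    }
  }
}
```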
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index d967814f62..9f76a61a15 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.columnar
-import java.nio.ByteBuffer
-
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.rdd.RDD
@@ -28,6 +26,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
import org.apache.spark.sql.execution.{ConvertToUnsafe, LeafNode, SparkPlan}
+import org.apache.spark.sql.types.UserDefinedType
import org.apache.spark.storage.StorageLevel
import org.apache.spark.{Accumulable, Accumulator, Accumulators}
@@ -43,7 +42,14 @@ private[sql] object InMemoryRelation {
tableName)()
}
-private[sql] case class CachedBatch(buffers: Array[Array[Byte]], stats: InternalRow)
+/**
+ * CachedBatch is a cached batch of rows.
+ *
+ * @param numRows The total number of rows in this batch
+ * @param buffers The buffers for serialized columns
+ * @param stats The statistics of the columns
+ */
+private[sql] case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow)
private[sql] case class InMemoryRelation(
output: Seq[Attribute],
@@ -151,7 +157,7 @@ private[sql] case class InMemoryRelation(
.flatMap(_.values))
batchStats += stats
- CachedBatch(columnBuilders.map(_.build().array()), stats)
+ CachedBatch(rowCount, columnBuilders.map(_.build().array()), stats)
}
def hasNext: Boolean = rowIterator.hasNext
@@ -278,59 +284,15 @@ private[sql] case class InMemoryColumnarTableScan(
val buffers = relation.cachedColumnBuffers
buffers.mapPartitions { cachedBatchIterator =>
- val partitionFilter = newPredicate(
- partitionFilters.reduceOption(And).getOrElse(Literal(true)),
- schema)
-
- // Find the ordinals and data types of the requested columns. If none are requested, use the
- // narrowest (the field with minimum default element size).
- val (requestedColumnIndices, requestedColumnDataTypes) = if (attributes.isEmpty) {
- val (narrowestOrdinal, narrowestDataType) =
- relOutput.zipWithIndex.map { case (a, ordinal) =>
- ordinal -> a.dataType
- } minBy { case (_, dataType) =>
- ColumnType(dataType).defaultSize
- }
- Seq(narrowestOrdinal) -> Seq(narrowestDataType)
- } else {
+ val partitionFilter = newPredicate(
+ partitionFilters.reduceOption(And).getOrElse(Literal(true)),
+ schema)
+
+ // Find the ordinals and data types of the requested columns.
+ val (requestedColumnIndices, requestedColumnDataTypes) =
attributes.map { a =>
relOutput.indexWhere(_.exprId == a.exprId) -> a.dataType
}.unzip
- }
-
- val nextRow = new SpecificMutableRow(requestedColumnDataTypes)
-
- def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]): Iterator[InternalRow] = {
- val rows = cacheBatches.flatMap { cachedBatch =>
- // Build column accessors
- val columnAccessors = requestedColumnIndices.map { batchColumnIndex =>
- ColumnAccessor(
- relOutput(batchColumnIndex).dataType,
- ByteBuffer.wrap(cachedBatch.buffers(batchColumnIndex)))
- }
-
- // Extract rows via column accessors
- new Iterator[InternalRow] {
- private[this] val rowLen = nextRow.numFields
- override def next(): InternalRow = {
- var i = 0
- while (i < rowLen) {
- columnAccessors(i).extractTo(nextRow, i)
- i += 1
- }
- if (attributes.isEmpty) InternalRow.empty else nextRow
- }
-
- override def hasNext: Boolean = columnAccessors(0).hasNext
- }
- }
-
- if (rows.hasNext && enableAccumulators) {
- readPartitions += 1
- }
-
- rows
- }
// Do partition batch pruning if enabled
val cachedBatchesToScan =
@@ -355,7 +317,18 @@ private[sql] case class InMemoryColumnarTableScan(
cachedBatchIterator
}
- cachedBatchesToRows(cachedBatchesToScan)
+ val nextRow = new SpecificMutableRow(requestedColumnDataTypes)
+ val columnTypes = requestedColumnDataTypes.map {
+ case udt: UserDefinedType[_] => udt.sqlType
+ case other => other
+ }.toArray
+ val columnarIterator = GenerateColumnAccessor.generate(columnTypes)
+ columnarIterator.initialize(cachedBatchesToScan, nextRow, columnTypes,
+ requestedColumnIndices.toArray)
+ if (enableAccumulators && columnarIterator.hasNext) {
+ readPartitions += 1
+ }
+ columnarIterator
}
}
}
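Re-expressing the commit message's Python benchmark as a hedged spark-shell sketch that exercises this new scan path; the dataset path and column names are the benchmark's own assumptions about a local TPC-DS store_sales table, and queryExecution.toRdd is the same developer API the Python version reaches through _jdf.

```
import org.apache.spark.sql.functions.col

val intCols = Seq("ss_sold_date_sk", "ss_sold_time_sk", "ss_item_sk", "ss_customer_sk")
val df = sqlContext.read.parquet("/opt/tpcds/store_sales/")
  .select(intCols.map(col): _*)
  .cache()
df.count()                                   // materialize the columnar cache

val start = System.nanoTime()
println(df.queryExecution.toRdd.count())     // scans via the generated ColumnarIterator
println(s"${(System.nanoTime() - start) / 1e9} s")
```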
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala
index 4d35650d4b..7eaecfe047 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala
@@ -31,8 +31,8 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor {
abstract override protected def initialize(): Unit = {
nullsBuffer = underlyingBuffer.duplicate().order(ByteOrder.nativeOrder())
- nullCount = nullsBuffer.getInt()
- nextNullIndex = if (nullCount > 0) nullsBuffer.getInt() else -1
+ nullCount = ByteBufferHelper.getInt(nullsBuffer)
+ nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1
pos = 0
underlyingBuffer.position(underlyingBuffer.position + 4 + nullCount * 4)
@@ -44,7 +44,7 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor {
seenNulls += 1
if (seenNulls < nullCount) {
- nextNullIndex = nullsBuffer.getInt()
+ nextNullIndex = ByteBufferHelper.getInt(nullsBuffer)
}
row.setNullAt(ordinal)
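For context, a hedged sketch of the null-tracking header this accessor decodes: a 4-byte null count followed by that many 4-byte row ordinals, in native order, ahead of the column data. `columnBytes` is a hypothetical serialized nullable column buffer.

```
import java.nio.{ByteBuffer, ByteOrder}
import org.apache.spark.sql.columnar.ByteBufferHelper

def readNullsHeader(columnBytes: Array[Byte]): (Int, Array[Int]) = {
  val nulls = ByteBuffer.wrap(columnBytes).order(ByteOrder.nativeOrder())
  val nullCount = ByteBufferHelper.getInt(nulls)
  val nullOrdinals = Array.fill(nullCount)(ByteBufferHelper.getInt(nulls))
  // the column's data section starts at offset 4 + nullCount * 4
  (nullCount, nullOrdinals)
}
```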
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
index ca910a99db..41c9a284e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
@@ -20,13 +20,11 @@ package org.apache.spark.sql.columnar.compression
import java.nio.ByteBuffer
import scala.collection.mutable
-import scala.reflect.ClassTag
-import scala.reflect.runtime.universe.runtimeMirror
+
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow}
import org.apache.spark.sql.columnar._
import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
private[sql] case object PassThrough extends CompressionScheme {
@@ -161,7 +159,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme {
override def next(row: MutableRow, ordinal: Int): Unit = {
if (valueCount == run) {
currentValue = columnType.extract(buffer)
- run = buffer.getInt()
+ run = ByteBufferHelper.getInt(buffer)
valueCount = 1
} else {
valueCount += 1
@@ -271,7 +269,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme {
extends compression.Decoder[T] {
private val dictionary: Array[Any] = {
- val elementNum = buffer.getInt()
+ val elementNum = ByteBufferHelper.getInt(buffer)
Array.fill[Any](elementNum)(columnType.extract(buffer).asInstanceOf[Any])
}
@@ -352,7 +350,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme {
}
class Decoder(buffer: ByteBuffer) extends compression.Decoder[BooleanType.type] {
- private val count = buffer.getInt()
+ private val count = ByteBufferHelper.getInt(buffer)
private var currentWord = 0: Long
@@ -363,7 +361,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme {
visited += 1
if (bit == 0) {
- currentWord = buffer.getLong()
+ currentWord = ByteBufferHelper.getLong(buffer)
}
row.setBoolean(ordinal, ((currentWord >> bit) & 1) != 0)
@@ -447,7 +445,7 @@ private[sql] case object IntDelta extends CompressionScheme {
override def next(row: MutableRow, ordinal: Int): Unit = {
val delta = buffer.get()
- prev = if (delta > Byte.MinValue) prev + delta else buffer.getInt()
+ prev = if (delta > Byte.MinValue) prev + delta else ByteBufferHelper.getInt(buffer)
row.setInt(ordinal, prev)
}
}
@@ -527,7 +525,7 @@ private[sql] case object LongDelta extends CompressionScheme {
override def next(row: MutableRow, ordinal: Int): Unit = {
val delta = buffer.get()
- prev = if (delta > Byte.MinValue) prev + delta else buffer.getLong()
+ prev = if (delta > Byte.MinValue) prev + delta else ByteBufferHelper.getLong(buffer)
row.setLong(ordinal, prev)
}
}
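A hedged sketch of the delta decode step the last two hunks touch: each value is stored as a single signed byte delta from the previous value, except when the delta does not fit, in which case the marker Byte.MinValue is written followed by the full-width value, which is now read through ByteBufferHelper instead of buffer.getInt/getLong.

```
import java.nio.ByteBuffer
import org.apache.spark.sql.columnar.ByteBufferHelper

// One decode step of IntDelta, mirroring the code above.
class IntDeltaDecoder {
  private var prev: Int = 0

  def next(buffer: ByteBuffer): Int = {
    val delta = buffer.get()
    prev = if (delta > Byte.MinValue) prev + delta else ByteBufferHelper.getInt(buffer)
    prev
  }
}
```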
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 4db9f4ee67..dc38fe59fe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -271,7 +271,7 @@ case class Coalesce(numPartitions: Int, child: SparkPlan) extends UnaryNode {
}
protected override def doExecute(): RDD[InternalRow] = {
- child.execute().map(_.copy()).coalesce(numPartitions, shuffle = false)
+ child.execute().coalesce(numPartitions, shuffle = false)
}
override def canProcessUnsafeRows: Boolean = true
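For context on why the per-row copy() can be dropped: coalesce with shuffle = false only regroups parent partitions and streams rows through, so the scan's reused mutable row is always consumed before it is mutated again; only consumers that buffer rows (a shuffle, a collect) need defensive copies. A self-contained sketch of that hazard, using a stand-in mutable cell rather than Spark's InternalRow:

```
final class MutableCell(var value: Int)

object ReuseDemo {
  def main(args: Array[String]): Unit = {
    val reused = new MutableCell(0)

    // Buffered consumption keeps references to the single reused object and
    // only sees its final state; this is what copy() protects against.
    val buffered = Iterator.tabulate(3) { i => reused.value = i; reused }.toVector
    println(buffered.map(_.value))   // Vector(2, 2, 2)

    // Streamed consumption reads each value before the next mutation, which is
    // the situation coalesce(shuffle = false) is in.
    val streamed = Iterator.tabulate(3) { i => reused.value = i; reused.value }
    println(streamed.sum)            // 3 (= 0 + 1 + 2)
  }
}
```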