author    Cheng Lian <lian.cs.zju@gmail.com>        2014-09-13 15:08:30 -0700
committer Michael Armbrust <michael@databricks.com> 2014-09-13 15:08:30 -0700
commit    74049249abb952ad061c0e221c22ff894a9e9c8d (patch)
tree      d5a8a9880c2565b3d5477e3b9054fc8018613db2 /sql
parent    184cd51c4207c23726da97f907f2d912a5a44845 (diff)
[SPARK-3294][SQL] Eliminates boxing costs from in-memory columnar storage
This is a major refactoring of the in-memory columnar storage implementation, aiming to eliminate boxing costs from the critical paths (building/accessing column buffers) as much as possible. The basic idea is to refactor all major interfaces into a row-based form and use them together with `SpecificMutableRow`. The difficult part is adapting all compression schemes, especially `RunLengthEncoding` and `DictionaryEncoding`, to this design. Since in-memory compression is disabled by default for now, and this PR should be strictly better than before regardless of whether in-memory compression is enabled, I may finish that part in another PR.

**UPDATE** This PR also took the chance to optimize `HiveTableScan` by

1. leveraging `SpecificMutableRow` to avoid boxing costs, and
2. building specific `Writable` unwrapper functions ahead of time to avoid per-row pattern matching and branching costs.

A sketch of both ideas follows this message.

TODO

- [x] Benchmark
- [ ] ~~Eliminate boxing costs in `RunLengthEncoding`~~ (left to future PRs)
- [ ] ~~Eliminate boxing costs in `DictionaryEncoding` (seems not easy to do without specializing `DictionaryEncoding` for every supported column type)~~ (left to future PRs)

## Micro benchmark

The benchmark uses a 10-million-line CSV table consisting of bytes, shorts, integers, longs, floats and doubles. It measures the time to build the in-memory version of this table and the time to scan the whole in-memory table. Benchmark code can be found [here](https://gist.github.com/liancheng/fe70a148de82e77bd2c8#file-hivetablescanbenchmark-scala). The script used to generate the input table can be found [here](https://gist.github.com/liancheng/fe70a148de82e77bd2c8#file-tablegen-scala).

Speedup:

- Hive table scanning + column buffer building: **18.74%**
  (The original benchmark uses an in-memory batch size of 1K; when increased to 10K, the speedup reaches 28.32%.)
- In-memory table scanning: **7.95%**

Before:

        | Building | Scanning
------- | -------- | --------
1       | 16472    | 525
2       | 16168    | 530
3       | 16386    | 529
4       | 16184    | 538
5       | 16209    | 521
Average | 16283.8  | 528.6

After:

        | Building | Scanning
------- | -------- | --------
1       | 13124    | 458
2       | 13260    | 529
3       | 12981    | 463
4       | 13214    | 483
5       | 13583    | 500
Average | 13232.4  | 486.6

Author: Cheng Lian <lian.cs.zju@gmail.com>

Closes #2327 from liancheng/prevent-boxing/unboxing and squashes the following commits:

4419fe4 [Cheng Lian] Addressing comments
e5d2cf2 [Cheng Lian] Bug fix: should call setNullAt when field value is null to avoid NPE
8b8552b [Cheng Lian] Only checks for partition batch pruning flag once
489f97b [Cheng Lian] Bug fix: TableReader.fillObject uses wrong ordinals
97bbc4e [Cheng Lian] Optimizes hive.TableReader by providing specific Writable unwrappers ahead of time
3dc1f94 [Cheng Lian] Minor changes to eliminate row object creation
5b39cb9 [Cheng Lian] Lowers log level of compression scheme details
f2a7890 [Cheng Lian] Use SpecificMutableRow in InMemoryColumnarTableScan to avoid boxing
9cf30b0 [Cheng Lian] Added row based ColumnType.append/extract
456c366 [Cheng Lian] Made compression decoder row based
edac3cd [Cheng Lian] Makes ColumnAccessor.extractSingle row based
8216936 [Cheng Lian] Removes boxing cost in IntDelta and LongDelta by providing specialized implementations
b70d519 [Cheng Lian] Made some in-memory columnar storage interfaces row-based
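For readers skimming the diff, here is a small, self-contained sketch of the row-based idea. The names `SimpleMutableIntRow`, `ValueBasedColumnType` and `RowBasedIntColumnType` are illustrative stand-ins for this sketch only, not the actual catalyst/columnar classes: extracting straight into a primitive slot of a mutable row avoids allocating a boxed value per field, whereas a generic, value-returning `extract` cannot.

```scala
import java.nio.ByteBuffer

// Illustrative stand-in for a mutable row backed by primitive storage.
final class SimpleMutableIntRow(size: Int) {
  private val values = new Array[Int](size)            // primitive slots, no boxing
  def setInt(ordinal: Int, v: Int): Unit = values(ordinal) = v
  def getInt(ordinal: Int): Int = values(ordinal)
}

// Value-returning interface (the "before" shape): a generic, non-specialized
// return type means every extracted Int is boxed into a java.lang.Integer.
trait ValueBasedColumnType[JvmType] {
  def extract(buffer: ByteBuffer): JvmType
}

// Row-based interface (the "after" shape): values move directly between the
// ByteBuffer and a primitive slot of the mutable row.
trait RowBasedIntColumnType {
  def extract(buffer: ByteBuffer, row: SimpleMutableIntRow, ordinal: Int): Unit
  def append(row: SimpleMutableIntRow, ordinal: Int, buffer: ByteBuffer): Unit
}

object IntColumn extends RowBasedIntColumnType {
  override def extract(buffer: ByteBuffer, row: SimpleMutableIntRow, ordinal: Int): Unit =
    row.setInt(ordinal, buffer.getInt())                // primitive path, no Integer allocated
  override def append(row: SimpleMutableIntRow, ordinal: Int, buffer: ByteBuffer): Unit =
    buffer.putInt(row.getInt(ordinal))
}

object RowBasedDemo extends App {
  val buffer = ByteBuffer.allocate(8)
  val row = new SimpleMutableIntRow(1)
  row.setInt(0, 42)
  IntColumn.append(row, 0, buffer)                      // build the column buffer
  buffer.flip()
  IntColumn.extract(buffer, row, 0)                     // scan it back into the row
  println(row.getInt(0))                                // 42
}
```

The `HiveTableScan` part of the change follows the same spirit: resolve per-column unwrapper functions once, before the row loop, instead of pattern matching on every field of every row. The sketch below is a simplified illustration using plain Hadoop `Writable` values; the actual `TableReader` code builds its unwrappers from Hive's serde machinery, not from sample values as done here.

```scala
import org.apache.hadoop.io.{IntWritable, Text, Writable}

object UnwrapperDemo extends App {
  // Resolve an unwrapper for a column once, based on a sample Writable.
  def unwrapperFor(sample: Writable): Writable => Any = sample match {
    case _: IntWritable => w => w.asInstanceOf[IntWritable].get()
    case _: Text        => w => w.asInstanceOf[Text].toString
    case _              => (w: Writable) => w
  }

  // One sample value per column, used only to pick the unwrappers up front.
  val columnSamples: Seq[Writable] = Seq(new IntWritable(0), new Text(""))
  val unwrappers = columnSamples.map(unwrapperFor)      // built once, outside the row loop

  val rows: Seq[Seq[Writable]] = Seq(Seq(new IntWritable(42), new Text("spark")))
  for (row <- rows) {
    // Hot loop: apply the prebuilt unwrappers, no per-field pattern matching.
    val unwrapped = row.zip(unwrappers).map { case (field, unwrap) => unwrap(field) }
    println(unwrapped)                                  // List(42, spark)
  }
}
```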
Diffstat (limited to 'sql')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala |   2
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala |   8
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala |  27
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala |  16
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala | 178
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala |  92
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala |   4
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala |   8
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala |   7
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala |  24
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala |  16
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala | 264
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala |   2
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala |  11
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala |   2
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala |   4
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala |   4
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala |   6
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala |   7
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala |   9
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala |   9
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala |   9
-rw-r--r-- sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala | 119
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala |  18
24 files changed, 554 insertions(+), 292 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala
index 088f11ee4a..9cbab3d5d0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala
@@ -171,7 +171,7 @@ final class MutableByte extends MutableValue {
}
final class MutableAny extends MutableValue {
- var value: Any = 0
+ var value: Any = _
def boxed = if (isNull) null else value
def update(v: Any) = value = {
isNull = false
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
index 42a5a9a84f..c9faf08521 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
@@ -50,11 +50,13 @@ private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType](
def hasNext = buffer.hasRemaining
- def extractTo(row: MutableRow, ordinal: Int) {
- columnType.setField(row, ordinal, extractSingle(buffer))
+ def extractTo(row: MutableRow, ordinal: Int): Unit = {
+ extractSingle(row, ordinal)
}
- def extractSingle(buffer: ByteBuffer): JvmType = columnType.extract(buffer)
+ def extractSingle(row: MutableRow, ordinal: Int): Unit = {
+ columnType.extract(buffer, row, ordinal)
+ }
protected def underlyingBuffer = buffer
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
index b3ec5ded22..2e61a98137 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
@@ -68,10 +68,9 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType](
buffer.order(ByteOrder.nativeOrder()).putInt(columnType.typeId)
}
- override def appendFrom(row: Row, ordinal: Int) {
- val field = columnType.getField(row, ordinal)
- buffer = ensureFreeSpace(buffer, columnType.actualSize(field))
- columnType.append(field, buffer)
+ override def appendFrom(row: Row, ordinal: Int): Unit = {
+ buffer = ensureFreeSpace(buffer, columnType.actualSize(row, ordinal))
+ columnType.append(row, ordinal, buffer)
}
override def build() = {
@@ -142,16 +141,16 @@ private[sql] object ColumnBuilder {
useCompression: Boolean = false): ColumnBuilder = {
val builder = (typeId match {
- case INT.typeId => new IntColumnBuilder
- case LONG.typeId => new LongColumnBuilder
- case FLOAT.typeId => new FloatColumnBuilder
- case DOUBLE.typeId => new DoubleColumnBuilder
- case BOOLEAN.typeId => new BooleanColumnBuilder
- case BYTE.typeId => new ByteColumnBuilder
- case SHORT.typeId => new ShortColumnBuilder
- case STRING.typeId => new StringColumnBuilder
- case BINARY.typeId => new BinaryColumnBuilder
- case GENERIC.typeId => new GenericColumnBuilder
+ case INT.typeId => new IntColumnBuilder
+ case LONG.typeId => new LongColumnBuilder
+ case FLOAT.typeId => new FloatColumnBuilder
+ case DOUBLE.typeId => new DoubleColumnBuilder
+ case BOOLEAN.typeId => new BooleanColumnBuilder
+ case BYTE.typeId => new ByteColumnBuilder
+ case SHORT.typeId => new ShortColumnBuilder
+ case STRING.typeId => new StringColumnBuilder
+ case BINARY.typeId => new BinaryColumnBuilder
+ case GENERIC.typeId => new GenericColumnBuilder
case TIMESTAMP.typeId => new TimestampColumnBuilder
}).asInstanceOf[ColumnBuilder]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
index fc343ccb99..203a714e03 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -69,7 +69,7 @@ private[sql] class ByteColumnStats extends ColumnStats {
var lower = Byte.MaxValue
var nullCount = 0
- override def gatherStats(row: Row, ordinal: Int) {
+ override def gatherStats(row: Row, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getByte(ordinal)
if (value > upper) upper = value
@@ -87,7 +87,7 @@ private[sql] class ShortColumnStats extends ColumnStats {
var lower = Short.MaxValue
var nullCount = 0
- override def gatherStats(row: Row, ordinal: Int) {
+ override def gatherStats(row: Row, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getShort(ordinal)
if (value > upper) upper = value
@@ -105,7 +105,7 @@ private[sql] class LongColumnStats extends ColumnStats {
var lower = Long.MaxValue
var nullCount = 0
- override def gatherStats(row: Row, ordinal: Int) {
+ override def gatherStats(row: Row, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getLong(ordinal)
if (value > upper) upper = value
@@ -123,7 +123,7 @@ private[sql] class DoubleColumnStats extends ColumnStats {
var lower = Double.MaxValue
var nullCount = 0
- override def gatherStats(row: Row, ordinal: Int) {
+ override def gatherStats(row: Row, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getDouble(ordinal)
if (value > upper) upper = value
@@ -141,7 +141,7 @@ private[sql] class FloatColumnStats extends ColumnStats {
var lower = Float.MaxValue
var nullCount = 0
- override def gatherStats(row: Row, ordinal: Int) {
+ override def gatherStats(row: Row, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getFloat(ordinal)
if (value > upper) upper = value
@@ -159,7 +159,7 @@ private[sql] class IntColumnStats extends ColumnStats {
var lower = Int.MaxValue
var nullCount = 0
- override def gatherStats(row: Row, ordinal: Int) {
+ override def gatherStats(row: Row, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getInt(ordinal)
if (value > upper) upper = value
@@ -177,7 +177,7 @@ private[sql] class StringColumnStats extends ColumnStats {
var lower: String = null
var nullCount = 0
- override def gatherStats(row: Row, ordinal: Int) {
+ override def gatherStats(row: Row, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getString(ordinal)
if (upper == null || value.compareTo(upper) > 0) upper = value
@@ -195,7 +195,7 @@ private[sql] class TimestampColumnStats extends ColumnStats {
var lower: Timestamp = null
var nullCount = 0
- override def gatherStats(row: Row, ordinal: Int) {
+ override def gatherStats(row: Row, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row(ordinal).asInstanceOf[Timestamp]
if (upper == null || value.compareTo(upper) > 0) upper = value
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index 9a61600115..198b575667 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -18,11 +18,10 @@
package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
+import java.sql.Timestamp
import scala.reflect.runtime.universe.TypeTag
-import java.sql.Timestamp
-
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.MutableRow
import org.apache.spark.sql.catalyst.types._
@@ -47,15 +46,32 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType](
def extract(buffer: ByteBuffer): JvmType
/**
+ * Extracts a value out of the buffer at the buffer's current position and stores in
+ * `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs whenever
+ * possible.
+ */
+ def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+ setField(row, ordinal, extract(buffer))
+ }
+
+ /**
* Appends the given value v of type T into the given ByteBuffer.
*/
- def append(v: JvmType, buffer: ByteBuffer)
+ def append(v: JvmType, buffer: ByteBuffer): Unit
+
+ /**
+ * Appends `row(ordinal)` of type T into the given ByteBuffer. Subclasses should override this
+ * method to avoid boxing/unboxing costs whenever possible.
+ */
+ def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = {
+ append(getField(row, ordinal), buffer)
+ }
/**
- * Returns the size of the value. This is used to calculate the size of variable length types
- * such as byte arrays and strings.
+ * Returns the size of the value `row(ordinal)`. This is used to calculate the size of variable
+ * length types such as byte arrays and strings.
*/
- def actualSize(v: JvmType): Int = defaultSize
+ def actualSize(row: Row, ordinal: Int): Int = defaultSize
/**
* Returns `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs
@@ -67,7 +83,15 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType](
* Sets `row(ordinal)` to `field`. Subclasses should override this method to avoid boxing/unboxing
* costs whenever possible.
*/
- def setField(row: MutableRow, ordinal: Int, value: JvmType)
+ def setField(row: MutableRow, ordinal: Int, value: JvmType): Unit
+
+ /**
+ * Copies `from(fromOrdinal)` to `to(toOrdinal)`. Subclasses should override this method to avoid
+ * boxing/unboxing costs whenever possible.
+ */
+ def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
+ to(toOrdinal) = from(fromOrdinal)
+ }
/**
* Creates a duplicated copy of the value.
@@ -90,119 +114,205 @@ private[sql] abstract class NativeColumnType[T <: NativeType](
}
private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) {
- def append(v: Int, buffer: ByteBuffer) {
+ def append(v: Int, buffer: ByteBuffer): Unit = {
buffer.putInt(v)
}
+ override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = {
+ buffer.putInt(row.getInt(ordinal))
+ }
+
def extract(buffer: ByteBuffer) = {
buffer.getInt()
}
- override def setField(row: MutableRow, ordinal: Int, value: Int) {
+ override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+ row.setInt(ordinal, buffer.getInt())
+ }
+
+ override def setField(row: MutableRow, ordinal: Int, value: Int): Unit = {
row.setInt(ordinal, value)
}
override def getField(row: Row, ordinal: Int) = row.getInt(ordinal)
+
+ override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
+ to.setInt(toOrdinal, from.getInt(fromOrdinal))
+ }
}
private[sql] object LONG extends NativeColumnType(LongType, 1, 8) {
- override def append(v: Long, buffer: ByteBuffer) {
+ override def append(v: Long, buffer: ByteBuffer): Unit = {
buffer.putLong(v)
}
+ override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = {
+ buffer.putLong(row.getLong(ordinal))
+ }
+
override def extract(buffer: ByteBuffer) = {
buffer.getLong()
}
- override def setField(row: MutableRow, ordinal: Int, value: Long) {
+ override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+ row.setLong(ordinal, buffer.getLong())
+ }
+
+ override def setField(row: MutableRow, ordinal: Int, value: Long): Unit = {
row.setLong(ordinal, value)
}
override def getField(row: Row, ordinal: Int) = row.getLong(ordinal)
+
+ override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
+ to.setLong(toOrdinal, from.getLong(fromOrdinal))
+ }
}
private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) {
- override def append(v: Float, buffer: ByteBuffer) {
+ override def append(v: Float, buffer: ByteBuffer): Unit = {
buffer.putFloat(v)
}
+ override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = {
+ buffer.putFloat(row.getFloat(ordinal))
+ }
+
override def extract(buffer: ByteBuffer) = {
buffer.getFloat()
}
- override def setField(row: MutableRow, ordinal: Int, value: Float) {
+ override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+ row.setFloat(ordinal, buffer.getFloat())
+ }
+
+ override def setField(row: MutableRow, ordinal: Int, value: Float): Unit = {
row.setFloat(ordinal, value)
}
override def getField(row: Row, ordinal: Int) = row.getFloat(ordinal)
+
+ override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
+ to.setFloat(toOrdinal, from.getFloat(fromOrdinal))
+ }
}
private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) {
- override def append(v: Double, buffer: ByteBuffer) {
+ override def append(v: Double, buffer: ByteBuffer): Unit = {
buffer.putDouble(v)
}
+ override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = {
+ buffer.putDouble(row.getDouble(ordinal))
+ }
+
override def extract(buffer: ByteBuffer) = {
buffer.getDouble()
}
- override def setField(row: MutableRow, ordinal: Int, value: Double) {
+ override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+ row.setDouble(ordinal, buffer.getDouble())
+ }
+
+ override def setField(row: MutableRow, ordinal: Int, value: Double): Unit = {
row.setDouble(ordinal, value)
}
override def getField(row: Row, ordinal: Int) = row.getDouble(ordinal)
+
+ override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
+ to.setDouble(toOrdinal, from.getDouble(fromOrdinal))
+ }
}
private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) {
- override def append(v: Boolean, buffer: ByteBuffer) {
- buffer.put(if (v) 1.toByte else 0.toByte)
+ override def append(v: Boolean, buffer: ByteBuffer): Unit = {
+ buffer.put(if (v) 1: Byte else 0: Byte)
+ }
+
+ override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = {
+ buffer.put(if (row.getBoolean(ordinal)) 1: Byte else 0: Byte)
}
override def extract(buffer: ByteBuffer) = buffer.get() == 1
- override def setField(row: MutableRow, ordinal: Int, value: Boolean) {
+ override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+ row.setBoolean(ordinal, buffer.get() == 1)
+ }
+
+ override def setField(row: MutableRow, ordinal: Int, value: Boolean): Unit = {
row.setBoolean(ordinal, value)
}
override def getField(row: Row, ordinal: Int) = row.getBoolean(ordinal)
+
+ override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
+ to.setBoolean(toOrdinal, from.getBoolean(fromOrdinal))
+ }
}
private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) {
- override def append(v: Byte, buffer: ByteBuffer) {
+ override def append(v: Byte, buffer: ByteBuffer): Unit = {
buffer.put(v)
}
+ override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = {
+ buffer.put(row.getByte(ordinal))
+ }
+
override def extract(buffer: ByteBuffer) = {
buffer.get()
}
- override def setField(row: MutableRow, ordinal: Int, value: Byte) {
+ override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+ row.setByte(ordinal, buffer.get())
+ }
+
+ override def setField(row: MutableRow, ordinal: Int, value: Byte): Unit = {
row.setByte(ordinal, value)
}
override def getField(row: Row, ordinal: Int) = row.getByte(ordinal)
+
+ override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
+ to.setByte(toOrdinal, from.getByte(fromOrdinal))
+ }
}
private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) {
- override def append(v: Short, buffer: ByteBuffer) {
+ override def append(v: Short, buffer: ByteBuffer): Unit = {
buffer.putShort(v)
}
+ override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = {
+ buffer.putShort(row.getShort(ordinal))
+ }
+
override def extract(buffer: ByteBuffer) = {
buffer.getShort()
}
- override def setField(row: MutableRow, ordinal: Int, value: Short) {
+ override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+ row.setShort(ordinal, buffer.getShort())
+ }
+
+ override def setField(row: MutableRow, ordinal: Int, value: Short): Unit = {
row.setShort(ordinal, value)
}
override def getField(row: Row, ordinal: Int) = row.getShort(ordinal)
+
+ override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
+ to.setShort(toOrdinal, from.getShort(fromOrdinal))
+ }
}
private[sql] object STRING extends NativeColumnType(StringType, 7, 8) {
- override def actualSize(v: String): Int = v.getBytes("utf-8").length + 4
+ override def actualSize(row: Row, ordinal: Int): Int = {
+ row.getString(ordinal).getBytes("utf-8").length + 4
+ }
- override def append(v: String, buffer: ByteBuffer) {
+ override def append(v: String, buffer: ByteBuffer): Unit = {
val stringBytes = v.getBytes("utf-8")
buffer.putInt(stringBytes.length).put(stringBytes, 0, stringBytes.length)
}
@@ -214,11 +324,15 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) {
new String(stringBytes, "utf-8")
}
- override def setField(row: MutableRow, ordinal: Int, value: String) {
+ override def setField(row: MutableRow, ordinal: Int, value: String): Unit = {
row.setString(ordinal, value)
}
override def getField(row: Row, ordinal: Int) = row.getString(ordinal)
+
+ override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
+ to.setString(toOrdinal, from.getString(fromOrdinal))
+ }
}
private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 8, 12) {
@@ -228,7 +342,7 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 8, 12) {
timestamp
}
- override def append(v: Timestamp, buffer: ByteBuffer) {
+ override def append(v: Timestamp, buffer: ByteBuffer): Unit = {
buffer.putLong(v.getTime).putInt(v.getNanos)
}
@@ -236,7 +350,7 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 8, 12) {
row(ordinal).asInstanceOf[Timestamp]
}
- override def setField(row: MutableRow, ordinal: Int, value: Timestamp) {
+ override def setField(row: MutableRow, ordinal: Int, value: Timestamp): Unit = {
row(ordinal) = value
}
}
@@ -246,9 +360,11 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType](
defaultSize: Int)
extends ColumnType[T, Array[Byte]](typeId, defaultSize) {
- override def actualSize(v: Array[Byte]) = v.length + 4
+ override def actualSize(row: Row, ordinal: Int) = {
+ getField(row, ordinal).length + 4
+ }
- override def append(v: Array[Byte], buffer: ByteBuffer) {
+ override def append(v: Array[Byte], buffer: ByteBuffer): Unit = {
buffer.putInt(v.length).put(v, 0, v.length)
}
@@ -261,7 +377,7 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType](
}
private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](9, 16) {
- override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]) {
+ override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = {
row(ordinal) = value
}
@@ -272,7 +388,7 @@ private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](9, 16) {
// serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized
// byte array.
private[sql] object GENERIC extends ByteArrayColumnType[DataType](10, 16) {
- override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]) {
+ override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = {
row(ordinal) = SparkSqlSerializer.deserialize[Any](value)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index 6eab2f23c1..8a3612cdf1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -52,7 +52,7 @@ private[sql] case class InMemoryRelation(
// As in Spark, the actual work of caching is lazy.
if (_cachedColumnBuffers == null) {
val output = child.output
- val cached = child.execute().mapPartitions { baseIterator =>
+ val cached = child.execute().mapPartitions { rowIterator =>
new Iterator[CachedBatch] {
def next() = {
val columnBuilders = output.map { attribute =>
@@ -61,11 +61,9 @@ private[sql] case class InMemoryRelation(
ColumnBuilder(columnType.typeId, initialBufferSize, attribute.name, useCompression)
}.toArray
- var row: Row = null
var rowCount = 0
-
- while (baseIterator.hasNext && rowCount < batchSize) {
- row = baseIterator.next()
+ while (rowIterator.hasNext && rowCount < batchSize) {
+ val row = rowIterator.next()
var i = 0
while (i < row.length) {
columnBuilders(i).appendFrom(row, i)
@@ -80,7 +78,7 @@ private[sql] case class InMemoryRelation(
CachedBatch(columnBuilders.map(_.build()), stats)
}
- def hasNext = baseIterator.hasNext
+ def hasNext = rowIterator.hasNext
}
}.cache()
@@ -182,6 +180,7 @@ private[sql] case class InMemoryColumnarTableScan(
}
}
+ // Accumulators used for testing purposes
val readPartitions = sparkContext.accumulator(0)
val readBatches = sparkContext.accumulator(0)
@@ -191,40 +190,36 @@ private[sql] case class InMemoryColumnarTableScan(
readPartitions.setValue(0)
readBatches.setValue(0)
- relation.cachedColumnBuffers.mapPartitions { iterator =>
+ relation.cachedColumnBuffers.mapPartitions { cachedBatchIterator =>
val partitionFilter = newPredicate(
partitionFilters.reduceOption(And).getOrElse(Literal(true)),
relation.partitionStatistics.schema)
- // Find the ordinals of the requested columns. If none are requested, use the first.
- val requestedColumns = if (attributes.isEmpty) {
- Seq(0)
+ // Find the ordinals and data types of the requested columns. If none are requested, use the
+ // narrowest (the field with minimum default element size).
+ val (requestedColumnIndices, requestedColumnDataTypes) = if (attributes.isEmpty) {
+ val (narrowestOrdinal, narrowestDataType) =
+ relation.output.zipWithIndex.map { case (a, ordinal) =>
+ ordinal -> a.dataType
+ } minBy { case (_, dataType) =>
+ ColumnType(dataType).defaultSize
+ }
+ Seq(narrowestOrdinal) -> Seq(narrowestDataType)
} else {
- attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId))
+ attributes.map { a =>
+ relation.output.indexWhere(_.exprId == a.exprId) -> a.dataType
+ }.unzip
}
- val rows = iterator
- // Skip pruned batches
- .filter { cachedBatch =>
- if (inMemoryPartitionPruningEnabled && !partitionFilter(cachedBatch.stats)) {
- def statsString = relation.partitionStatistics.schema
- .zip(cachedBatch.stats)
- .map { case (a, s) => s"${a.name}: $s" }
- .mkString(", ")
- logInfo(s"Skipping partition based on stats $statsString")
- false
- } else {
- readBatches += 1
- true
- }
- }
- // Build column accessors
- .map { cachedBatch =>
- requestedColumns.map(cachedBatch.buffers(_)).map(ColumnAccessor(_))
- }
- // Extract rows via column accessors
- .flatMap { columnAccessors =>
- val nextRow = new GenericMutableRow(columnAccessors.length)
+ val nextRow = new SpecificMutableRow(requestedColumnDataTypes)
+
+ def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]) = {
+ val rows = cacheBatches.flatMap { cachedBatch =>
+ // Build column accessors
+ val columnAccessors =
+ requestedColumnIndices.map(cachedBatch.buffers(_)).map(ColumnAccessor(_))
+
+ // Extract rows via column accessors
new Iterator[Row] {
override def next() = {
var i = 0
@@ -235,15 +230,38 @@ private[sql] case class InMemoryColumnarTableScan(
nextRow
}
- override def hasNext = columnAccessors.head.hasNext
+ override def hasNext = columnAccessors(0).hasNext
}
}
- if (rows.hasNext) {
- readPartitions += 1
+ if (rows.hasNext) {
+ readPartitions += 1
+ }
+
+ rows
}
- rows
+ // Do partition batch pruning if enabled
+ val cachedBatchesToScan =
+ if (inMemoryPartitionPruningEnabled) {
+ cachedBatchIterator.filter { cachedBatch =>
+ if (!partitionFilter(cachedBatch.stats)) {
+ def statsString = relation.partitionStatistics.schema
+ .zip(cachedBatch.stats)
+ .map { case (a, s) => s"${a.name}: $s" }
+ .mkString(", ")
+ logInfo(s"Skipping partition based on stats $statsString")
+ false
+ } else {
+ readBatches += 1
+ true
+ }
+ }
+ } else {
+ cachedBatchIterator
+ }
+
+ cachedBatchesToRows(cachedBatchesToScan)
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala
index b7f8826861..965782a400 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala
@@ -29,7 +29,7 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor {
private var nextNullIndex: Int = _
private var pos: Int = 0
- abstract override protected def initialize() {
+ abstract override protected def initialize(): Unit = {
nullsBuffer = underlyingBuffer.duplicate().order(ByteOrder.nativeOrder())
nullCount = nullsBuffer.getInt()
nextNullIndex = if (nullCount > 0) nullsBuffer.getInt() else -1
@@ -39,7 +39,7 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor {
super.initialize()
}
- abstract override def extractTo(row: MutableRow, ordinal: Int) {
+ abstract override def extractTo(row: MutableRow, ordinal: Int): Unit = {
if (pos == nextNullIndex) {
seenNulls += 1
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
index a72970eef7..f1f494ac26 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
@@ -40,7 +40,11 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder {
protected var nullCount: Int = _
private var pos: Int = _
- abstract override def initialize(initialSize: Int, columnName: String, useCompression: Boolean) {
+ abstract override def initialize(
+ initialSize: Int,
+ columnName: String,
+ useCompression: Boolean): Unit = {
+
nulls = ByteBuffer.allocate(1024)
nulls.order(ByteOrder.nativeOrder())
pos = 0
@@ -48,7 +52,7 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder {
super.initialize(initialSize, columnName, useCompression)
}
- abstract override def appendFrom(row: Row, ordinal: Int) {
+ abstract override def appendFrom(row: Row, ordinal: Int): Unit = {
columnStats.gatherStats(row, ordinal)
if (row.isNullAt(ordinal)) {
nulls = ColumnBuilder.ensureFreeSpace(nulls, 4)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala
index b4120a3d43..27ac5f4dbd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala
@@ -17,8 +17,7 @@
package org.apache.spark.sql.columnar.compression
-import java.nio.ByteBuffer
-
+import org.apache.spark.sql.catalyst.expressions.MutableRow
import org.apache.spark.sql.catalyst.types.NativeType
import org.apache.spark.sql.columnar.{ColumnAccessor, NativeColumnAccessor}
@@ -34,5 +33,7 @@ private[sql] trait CompressibleColumnAccessor[T <: NativeType] extends ColumnAcc
abstract override def hasNext = super.hasNext || decoder.hasNext
- override def extractSingle(buffer: ByteBuffer): T#JvmType = decoder.next()
+ override def extractSingle(row: MutableRow, ordinal: Int): Unit = {
+ decoder.next(row, ordinal)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala
index a5826bb033..628d9cec41 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala
@@ -48,12 +48,16 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType]
var compressionEncoders: Seq[Encoder[T]] = _
- abstract override def initialize(initialSize: Int, columnName: String, useCompression: Boolean) {
+ abstract override def initialize(
+ initialSize: Int,
+ columnName: String,
+ useCompression: Boolean): Unit = {
+
compressionEncoders =
if (useCompression) {
- schemes.filter(_.supports(columnType)).map(_.encoder[T])
+ schemes.filter(_.supports(columnType)).map(_.encoder[T](columnType))
} else {
- Seq(PassThrough.encoder)
+ Seq(PassThrough.encoder(columnType))
}
super.initialize(initialSize, columnName, useCompression)
}
@@ -62,17 +66,15 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType]
encoder.compressionRatio < 0.8
}
- private def gatherCompressibilityStats(row: Row, ordinal: Int) {
- val field = columnType.getField(row, ordinal)
-
+ private def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {
var i = 0
while (i < compressionEncoders.length) {
- compressionEncoders(i).gatherCompressibilityStats(field, columnType)
+ compressionEncoders(i).gatherCompressibilityStats(row, ordinal)
i += 1
}
}
- abstract override def appendFrom(row: Row, ordinal: Int) {
+ abstract override def appendFrom(row: Row, ordinal: Int): Unit = {
super.appendFrom(row, ordinal)
if (!row.isNullAt(ordinal)) {
gatherCompressibilityStats(row, ordinal)
@@ -84,7 +86,7 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType]
val typeId = nonNullBuffer.getInt()
val encoder: Encoder[T] = {
val candidate = compressionEncoders.minBy(_.compressionRatio)
- if (isWorthCompressing(candidate)) candidate else PassThrough.encoder
+ if (isWorthCompressing(candidate)) candidate else PassThrough.encoder(columnType)
}
// Header = column type ID + null count + null positions
@@ -104,7 +106,7 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType]
.putInt(nullCount)
.put(nulls)
- logInfo(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}")
- encoder.compress(nonNullBuffer, compressedBuffer, columnType)
+ logDebug(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}")
+ encoder.compress(nonNullBuffer, compressedBuffer)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
index 7797f75177..acb06cb537 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
@@ -17,13 +17,15 @@
package org.apache.spark.sql.columnar.compression
-import java.nio.{ByteOrder, ByteBuffer}
+import java.nio.{ByteBuffer, ByteOrder}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.MutableRow
import org.apache.spark.sql.catalyst.types.NativeType
import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType}
private[sql] trait Encoder[T <: NativeType] {
- def gatherCompressibilityStats(value: T#JvmType, columnType: NativeColumnType[T]) {}
+ def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {}
def compressedSize: Int
@@ -33,17 +35,21 @@ private[sql] trait Encoder[T <: NativeType] {
if (uncompressedSize > 0) compressedSize.toDouble / uncompressedSize else 1.0
}
- def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]): ByteBuffer
+ def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer
}
-private[sql] trait Decoder[T <: NativeType] extends Iterator[T#JvmType]
+private[sql] trait Decoder[T <: NativeType] {
+ def next(row: MutableRow, ordinal: Int): Unit
+
+ def hasNext: Boolean
+}
private[sql] trait CompressionScheme {
def typeId: Int
def supports(columnType: ColumnType[_, _]): Boolean
- def encoder[T <: NativeType]: Encoder[T]
+ def encoder[T <: NativeType](columnType: NativeColumnType[T]): Encoder[T]
def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
index 8cf9ec74ca..29edcf1724 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
@@ -23,7 +23,8 @@ import scala.collection.mutable
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.runtimeMirror
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow}
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.columnar._
import org.apache.spark.util.Utils
@@ -33,18 +34,20 @@ private[sql] case object PassThrough extends CompressionScheme {
override def supports(columnType: ColumnType[_, _]) = true
- override def encoder[T <: NativeType] = new this.Encoder[T]
+ override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = {
+ new this.Encoder[T](columnType)
+ }
override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = {
new this.Decoder(buffer, columnType)
}
- class Encoder[T <: NativeType] extends compression.Encoder[T] {
+ class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] {
override def uncompressedSize = 0
override def compressedSize = 0
- override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]) = {
+ override def compress(from: ByteBuffer, to: ByteBuffer) = {
// Writes compression type ID and copies raw contents
to.putInt(PassThrough.typeId).put(from).rewind()
to
@@ -54,7 +57,9 @@ private[sql] case object PassThrough extends CompressionScheme {
class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T])
extends compression.Decoder[T] {
- override def next() = columnType.extract(buffer)
+ override def next(row: MutableRow, ordinal: Int): Unit = {
+ columnType.extract(buffer, row, ordinal)
+ }
override def hasNext = buffer.hasRemaining
}
@@ -63,7 +68,9 @@ private[sql] case object PassThrough extends CompressionScheme {
private[sql] case object RunLengthEncoding extends CompressionScheme {
override val typeId = 1
- override def encoder[T <: NativeType] = new this.Encoder[T]
+ override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = {
+ new this.Encoder[T](columnType)
+ }
override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = {
new this.Decoder(buffer, columnType)
@@ -74,24 +81,25 @@ private[sql] case object RunLengthEncoding extends CompressionScheme {
case _ => false
}
- class Encoder[T <: NativeType] extends compression.Encoder[T] {
+ class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] {
private var _uncompressedSize = 0
private var _compressedSize = 0
// Using `MutableRow` to store the last value to avoid boxing/unboxing cost.
- private val lastValue = new GenericMutableRow(1)
+ private val lastValue = new SpecificMutableRow(Seq(columnType.dataType))
private var lastRun = 0
override def uncompressedSize = _uncompressedSize
override def compressedSize = _compressedSize
- override def gatherCompressibilityStats(value: T#JvmType, columnType: NativeColumnType[T]) {
- val actualSize = columnType.actualSize(value)
+ override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {
+ val value = columnType.getField(row, ordinal)
+ val actualSize = columnType.actualSize(row, ordinal)
_uncompressedSize += actualSize
if (lastValue.isNullAt(0)) {
- columnType.setField(lastValue, 0, value)
+ columnType.copyField(row, ordinal, lastValue, 0)
lastRun = 1
_compressedSize += actualSize + 4
} else {
@@ -99,37 +107,40 @@ private[sql] case object RunLengthEncoding extends CompressionScheme {
lastRun += 1
} else {
_compressedSize += actualSize + 4
- columnType.setField(lastValue, 0, value)
+ columnType.copyField(row, ordinal, lastValue, 0)
lastRun = 1
}
}
}
- override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]) = {
+ override def compress(from: ByteBuffer, to: ByteBuffer) = {
to.putInt(RunLengthEncoding.typeId)
if (from.hasRemaining) {
- var currentValue = columnType.extract(from)
+ val currentValue = new SpecificMutableRow(Seq(columnType.dataType))
var currentRun = 1
+ val value = new SpecificMutableRow(Seq(columnType.dataType))
+
+ columnType.extract(from, currentValue, 0)
while (from.hasRemaining) {
- val value = columnType.extract(from)
+ columnType.extract(from, value, 0)
- if (value == currentValue) {
+ if (value.head == currentValue.head) {
currentRun += 1
} else {
// Writes current run
- columnType.append(currentValue, to)
+ columnType.append(currentValue, 0, to)
to.putInt(currentRun)
// Resets current run
- currentValue = value
+ columnType.copyField(value, 0, currentValue, 0)
currentRun = 1
}
}
// Writes the last run
- columnType.append(currentValue, to)
+ columnType.append(currentValue, 0, to)
to.putInt(currentRun)
}
@@ -145,7 +156,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme {
private var valueCount = 0
private var currentValue: T#JvmType = _
- override def next() = {
+ override def next(row: MutableRow, ordinal: Int): Unit = {
if (valueCount == run) {
currentValue = columnType.extract(buffer)
run = buffer.getInt()
@@ -154,7 +165,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme {
valueCount += 1
}
- currentValue
+ columnType.setField(row, ordinal, currentValue)
}
override def hasNext = valueCount < run || buffer.hasRemaining
@@ -171,14 +182,16 @@ private[sql] case object DictionaryEncoding extends CompressionScheme {
new this.Decoder(buffer, columnType)
}
- override def encoder[T <: NativeType] = new this.Encoder[T]
+ override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = {
+ new this.Encoder[T](columnType)
+ }
override def supports(columnType: ColumnType[_, _]) = columnType match {
case INT | LONG | STRING => true
case _ => false
}
- class Encoder[T <: NativeType] extends compression.Encoder[T] {
+ class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] {
// Size of the input, uncompressed, in bytes. Note that we only count until the dictionary
// overflows.
private var _uncompressedSize = 0
@@ -200,9 +213,11 @@ private[sql] case object DictionaryEncoding extends CompressionScheme {
// to store dictionary element count.
private var dictionarySize = 4
- override def gatherCompressibilityStats(value: T#JvmType, columnType: NativeColumnType[T]) {
+ override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {
+ val value = columnType.getField(row, ordinal)
+
if (!overflow) {
- val actualSize = columnType.actualSize(value)
+ val actualSize = columnType.actualSize(row, ordinal)
count += 1
_uncompressedSize += actualSize
@@ -221,7 +236,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme {
}
}
- override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]) = {
+ override def compress(from: ByteBuffer, to: ByteBuffer) = {
if (overflow) {
throw new IllegalStateException(
"Dictionary encoding should not be used because of dictionary overflow.")
@@ -264,7 +279,9 @@ private[sql] case object DictionaryEncoding extends CompressionScheme {
}
}
- override def next() = dictionary(buffer.getShort())
+ override def next(row: MutableRow, ordinal: Int): Unit = {
+ columnType.setField(row, ordinal, dictionary(buffer.getShort()))
+ }
override def hasNext = buffer.hasRemaining
}
@@ -279,25 +296,20 @@ private[sql] case object BooleanBitSet extends CompressionScheme {
new this.Decoder(buffer).asInstanceOf[compression.Decoder[T]]
}
- override def encoder[T <: NativeType] = (new this.Encoder).asInstanceOf[compression.Encoder[T]]
+ override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = {
+ (new this.Encoder).asInstanceOf[compression.Encoder[T]]
+ }
override def supports(columnType: ColumnType[_, _]) = columnType == BOOLEAN
class Encoder extends compression.Encoder[BooleanType.type] {
private var _uncompressedSize = 0
- override def gatherCompressibilityStats(
- value: Boolean,
- columnType: NativeColumnType[BooleanType.type]) {
-
+ override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {
_uncompressedSize += BOOLEAN.defaultSize
}
- override def compress(
- from: ByteBuffer,
- to: ByteBuffer,
- columnType: NativeColumnType[BooleanType.type]) = {
-
+ override def compress(from: ByteBuffer, to: ByteBuffer) = {
to.putInt(BooleanBitSet.typeId)
// Total element count (1 byte per Boolean value)
.putInt(from.remaining)
@@ -349,7 +361,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme {
private var visited: Int = 0
- override def next(): Boolean = {
+ override def next(row: MutableRow, ordinal: Int): Unit = {
val bit = visited % BITS_PER_LONG
visited += 1
@@ -357,123 +369,167 @@ private[sql] case object BooleanBitSet extends CompressionScheme {
currentWord = buffer.getLong()
}
- ((currentWord >> bit) & 1) != 0
+ row.setBoolean(ordinal, ((currentWord >> bit) & 1) != 0)
}
override def hasNext: Boolean = visited < count
}
}
-private[sql] sealed abstract class IntegralDelta[I <: IntegralType] extends CompressionScheme {
+private[sql] case object IntDelta extends CompressionScheme {
+ override def typeId: Int = 4
+
override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = {
- new this.Decoder(buffer, columnType.asInstanceOf[NativeColumnType[I]])
- .asInstanceOf[compression.Decoder[T]]
+ new Decoder(buffer, INT).asInstanceOf[compression.Decoder[T]]
}
- override def encoder[T <: NativeType] = (new this.Encoder).asInstanceOf[compression.Encoder[T]]
-
- /**
- * Computes `delta = x - y`, returns `(true, delta)` if `delta` can fit into a single byte, or
- * `(false, 0: Byte)` otherwise.
- */
- protected def byteSizedDelta(x: I#JvmType, y: I#JvmType): (Boolean, Byte)
+ override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = {
+ (new Encoder).asInstanceOf[compression.Encoder[T]]
+ }
- /**
- * Simply computes `x + delta`
- */
- protected def addDelta(x: I#JvmType, delta: Byte): I#JvmType
+ override def supports(columnType: ColumnType[_, _]) = columnType == INT
- class Encoder extends compression.Encoder[I] {
- private var _compressedSize: Int = 0
+ class Encoder extends compression.Encoder[IntegerType.type] {
+ protected var _compressedSize: Int = 0
+ protected var _uncompressedSize: Int = 0
- private var _uncompressedSize: Int = 0
+ override def compressedSize = _compressedSize
+ override def uncompressedSize = _uncompressedSize
- private var prev: I#JvmType = _
+ private var prevValue: Int = _
- private var initial = true
+ override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {
+ val value = row.getInt(ordinal)
+ val delta = value - prevValue
- override def gatherCompressibilityStats(value: I#JvmType, columnType: NativeColumnType[I]) {
- _uncompressedSize += columnType.defaultSize
+ _compressedSize += 1
- if (initial) {
- initial = false
- _compressedSize += 1 + columnType.defaultSize
- } else {
- val (smallEnough, _) = byteSizedDelta(value, prev)
- _compressedSize += (if (smallEnough) 1 else 1 + columnType.defaultSize)
+ // If this is the first integer to be compressed, or the delta is out of byte range, then give
+ // up compressing this integer.
+ if (_uncompressedSize == 0 || delta <= Byte.MinValue || delta > Byte.MaxValue) {
+ _compressedSize += INT.defaultSize
}
- prev = value
+ _uncompressedSize += INT.defaultSize
+ prevValue = value
}
- override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[I]) = {
+ override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = {
to.putInt(typeId)
if (from.hasRemaining) {
- var prev = columnType.extract(from)
+ var prev = from.getInt()
to.put(Byte.MinValue)
- columnType.append(prev, to)
+ to.putInt(prev)
while (from.hasRemaining) {
- val current = columnType.extract(from)
- val (smallEnough, delta) = byteSizedDelta(current, prev)
+ val current = from.getInt()
+ val delta = current - prev
prev = current
- if (smallEnough) {
- to.put(delta)
+ if (Byte.MinValue < delta && delta <= Byte.MaxValue) {
+ to.put(delta.toByte)
} else {
to.put(Byte.MinValue)
- columnType.append(current, to)
+ to.putInt(current)
}
}
}
- to.rewind()
- to
+ to.rewind().asInstanceOf[ByteBuffer]
}
-
- override def uncompressedSize = _uncompressedSize
-
- override def compressedSize = _compressedSize
}
- class Decoder(buffer: ByteBuffer, columnType: NativeColumnType[I])
- extends compression.Decoder[I] {
+ class Decoder(buffer: ByteBuffer, columnType: NativeColumnType[IntegerType.type])
+ extends compression.Decoder[IntegerType.type] {
+
+ private var prev: Int = _
- private var prev: I#JvmType = _
+ override def hasNext: Boolean = buffer.hasRemaining
- override def next() = {
+ override def next(row: MutableRow, ordinal: Int): Unit = {
val delta = buffer.get()
- prev = if (delta > Byte.MinValue) addDelta(prev, delta) else columnType.extract(buffer)
- prev
+ prev = if (delta > Byte.MinValue) prev + delta else buffer.getInt()
+ row.setInt(ordinal, prev)
}
-
- override def hasNext = buffer.hasRemaining
}
}
-private[sql] case object IntDelta extends IntegralDelta[IntegerType.type] {
- override val typeId = 4
+private[sql] case object LongDelta extends CompressionScheme {
+ override def typeId: Int = 5
- override def supports(columnType: ColumnType[_, _]) = columnType == INT
+ override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = {
+ new Decoder(buffer, LONG).asInstanceOf[compression.Decoder[T]]
+ }
+
+ override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = {
+ (new Encoder).asInstanceOf[compression.Encoder[T]]
+ }
- override protected def addDelta(x: Int, delta: Byte) = x + delta
+ override def supports(columnType: ColumnType[_, _]) = columnType == LONG
+
+ class Encoder extends compression.Encoder[LongType.type] {
+ protected var _compressedSize: Int = 0
+ protected var _uncompressedSize: Int = 0
+
+ override def compressedSize = _compressedSize
+ override def uncompressedSize = _uncompressedSize
+
+ private var prevValue: Long = _
+
+ override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {
+ val value = row.getLong(ordinal)
+ val delta = value - prevValue
+
+ _compressedSize += 1
- override protected def byteSizedDelta(x: Int, y: Int): (Boolean, Byte) = {
- val delta = x - y
- if (math.abs(delta) <= Byte.MaxValue) (true, delta.toByte) else (false, 0: Byte)
+ // If this is the first long integer to be compressed, or the delta is out of byte range, then
+ // give up compressing this long integer.
+ if (_uncompressedSize == 0 || delta <= Byte.MinValue || delta > Byte.MaxValue) {
+ _compressedSize += LONG.defaultSize
+ }
+
+ _uncompressedSize += LONG.defaultSize
+ prevValue = value
+ }
+
+ override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = {
+ to.putInt(typeId)
+
+ if (from.hasRemaining) {
+ var prev = from.getLong()
+ to.put(Byte.MinValue)
+ to.putLong(prev)
+
+ while (from.hasRemaining) {
+ val current = from.getLong()
+ val delta = current - prev
+ prev = current
+
+ if (Byte.MinValue < delta && delta <= Byte.MaxValue) {
+ to.put(delta.toByte)
+ } else {
+ to.put(Byte.MinValue)
+ to.putLong(current)
+ }
+ }
+ }
+
+ to.rewind().asInstanceOf[ByteBuffer]
+ }
}
-}
-private[sql] case object LongDelta extends IntegralDelta[LongType.type] {
- override val typeId = 5
+ class Decoder(buffer: ByteBuffer, columnType: NativeColumnType[LongType.type])
+ extends compression.Decoder[LongType.type] {
- override def supports(columnType: ColumnType[_, _]) = columnType == LONG
+ private var prev: Long = _
- override protected def addDelta(x: Long, delta: Byte) = x + delta
+ override def hasNext: Boolean = buffer.hasRemaining
- override protected def byteSizedDelta(x: Long, y: Long): (Boolean, Byte) = {
- val delta = x - y
- if (math.abs(delta) <= Byte.MaxValue) (true, delta.toByte) else (false, 0: Byte)
+ override def next(row: MutableRow, ordinal: Int): Unit = {
+ val delta = buffer.get()
+ prev = if (delta > Byte.MinValue) prev + delta else buffer.getLong()
+ row.setLong(ordinal, prev)
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index cde91ceb68..0cdbb3167c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -35,7 +35,7 @@ class ColumnStatsSuite extends FunSuite {
def testColumnStats[T <: NativeType, U <: ColumnStats](
columnStatsClass: Class[U],
columnType: NativeColumnType[T],
- initialStatistics: Row) {
+ initialStatistics: Row): Unit = {
val columnStatsName = columnStatsClass.getSimpleName
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index 75f653f328..4fb1ecf1d5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -23,6 +23,7 @@ import java.sql.Timestamp
import org.scalatest.FunSuite
import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.columnar.ColumnarTestUtils._
import org.apache.spark.sql.execution.SparkSqlSerializer
@@ -46,10 +47,12 @@ class ColumnTypeSuite extends FunSuite with Logging {
def checkActualSize[T <: DataType, JvmType](
columnType: ColumnType[T, JvmType],
value: JvmType,
- expected: Int) {
+ expected: Int): Unit = {
assertResult(expected, s"Wrong actualSize for $columnType") {
- columnType.actualSize(value)
+ val row = new GenericMutableRow(1)
+ columnType.setField(row, 0, value)
+ columnType.actualSize(row, 0)
}
}
@@ -147,7 +150,7 @@ class ColumnTypeSuite extends FunSuite with Logging {
def testNativeColumnType[T <: NativeType](
columnType: NativeColumnType[T],
putter: (ByteBuffer, T#JvmType) => Unit,
- getter: (ByteBuffer) => T#JvmType) {
+ getter: (ByteBuffer) => T#JvmType): Unit = {
testColumnType[T, T#JvmType](columnType, putter, getter)
}
@@ -155,7 +158,7 @@ class ColumnTypeSuite extends FunSuite with Logging {
def testColumnType[T <: DataType, JvmType](
columnType: ColumnType[T, JvmType],
putter: (ByteBuffer, JvmType) => Unit,
- getter: (ByteBuffer) => JvmType) {
+ getter: (ByteBuffer) => JvmType): Unit = {
val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE)
val seq = (0 until 4).map(_ => makeRandomValue(columnType))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index 0e3c67f5ee..c1278248ef 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.columnar
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.{SQLConf, QueryTest, TestData}
+import org.apache.spark.sql.{QueryTest, TestData}
class InMemoryColumnarQuerySuite extends QueryTest {
import org.apache.spark.sql.TestData._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
index 3baa6f8ec0..6c9a9ab6c3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
@@ -45,7 +45,9 @@ class NullableColumnAccessorSuite extends FunSuite {
testNullableColumnAccessor(_)
}
- def testNullableColumnAccessor[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) {
+ def testNullableColumnAccessor[T <: DataType, JvmType](
+ columnType: ColumnType[T, JvmType]): Unit = {
+
val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
val nullRow = makeNullRow(1)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
index a77262534a..f54a21eb4f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
@@ -41,7 +41,9 @@ class NullableColumnBuilderSuite extends FunSuite {
testNullableColumnBuilder(_)
}
- def testNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) {
+ def testNullableColumnBuilder[T <: DataType, JvmType](
+ columnType: ColumnType[T, JvmType]): Unit = {
+
val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
test(s"$typeName column builder: empty column") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
index 5d2fd49591..69e0adbd3e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
@@ -28,7 +28,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be
val originalColumnBatchSize = columnBatchSize
val originalInMemoryPartitionPruning = inMemoryPartitionPruning
- override protected def beforeAll() {
+ override protected def beforeAll(): Unit = {
// Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
setConf(SQLConf.COLUMN_BATCH_SIZE, "10")
val rawData = sparkContext.makeRDD(1 to 100, 5).map(IntegerData)
@@ -38,7 +38,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be
setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
}
- override protected def afterAll() {
+ override protected def afterAll(): Unit = {
setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString)
setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString)
}
@@ -76,7 +76,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be
filter: String,
expectedQueryResult: Seq[Int],
expectedReadPartitions: Int,
- expectedReadBatches: Int) {
+ expectedReadBatches: Int): Unit = {
test(filter) {
val query = sql(s"SELECT * FROM intData WHERE $filter")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
index e01cc8b4d2..d9e488e0ff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.columnar.compression
import org.scalatest.FunSuite
import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.columnar.{NoopColumnStats, BOOLEAN}
import org.apache.spark.sql.columnar.ColumnarTestUtils._
@@ -72,10 +73,14 @@ class BooleanBitSetSuite extends FunSuite {
buffer.rewind().position(headerSize + 4)
val decoder = BooleanBitSet.decoder(buffer, BOOLEAN)
+ val mutableRow = new GenericMutableRow(1)
if (values.nonEmpty) {
values.foreach {
assert(decoder.hasNext)
- assertResult(_, "Wrong decoded value")(decoder.next())
+ assertResult(_, "Wrong decoded value") {
+ decoder.next(mutableRow, 0)
+ mutableRow.getBoolean(0)
+ }
}
}
assert(!decoder.hasNext)
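The test above, like the dictionary, delta, and run-length suites that follow, now drives decoders through the row-based `next(row, ordinal)` call and reads the result back with a primitive getter, instead of consuming a boxed value from `next()`. The following is a minimal sketch of why that interface shape avoids per-value boxing; `OneSlotRow`, `RowDecoder`, and the other names are simplified stand-ins for illustration, not Spark's `MutableRow` or `compression.Decoder`.

```scala
import java.nio.ByteBuffer

object RowBasedDecoderSketch {
  // Stand-in for Spark's MutableRow: a single slot with a primitive setter/getter.
  final class OneSlotRow {
    private var longValue: Long = 0L
    def setLong(ordinal: Int, v: Long): Unit = { require(ordinal == 0); longValue = v }
    def getLong(ordinal: Int): Long = { require(ordinal == 0); longValue }
  }

  // Old shape: a generic next() has to hand back a boxed value for primitives.
  trait ValueDecoder[T] { def hasNext: Boolean; def next(): T }

  // New shape, as exercised by the suites above: the decoder writes straight
  // into a reusable row slot through a primitive setter.
  trait RowDecoder { def hasNext: Boolean; def next(row: OneSlotRow, ordinal: Int): Unit }

  // A trivial decoder that just reads longs back out of a buffer.
  class PassThroughLongDecoder(buffer: ByteBuffer) extends RowDecoder {
    override def hasNext: Boolean = buffer.hasRemaining
    override def next(row: OneSlotRow, ordinal: Int): Unit =
      row.setLong(ordinal, buffer.getLong())
  }

  // Test-style consumption: decode into the one-slot row, read the primitive back.
  def drain(decoder: RowDecoder): Seq[Long] = {
    val row = new OneSlotRow
    val out = Seq.newBuilder[Long]
    while (decoder.hasNext) {
      decoder.next(row, 0)
      out += row.getLong(0)
    }
    out.result()
  }
}
```

The real suites do the same thing with a `GenericMutableRow(1)` and `columnType.getField(mutableRow, 0)`; the payoff is that a specialized decoder such as the `LongDelta` decoder above can call `setLong` directly and never materialize a `java.lang.Long` per value.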
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
index d2969d906c..1cdb909146 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
@@ -21,6 +21,7 @@ import java.nio.ByteBuffer
import org.scalatest.FunSuite
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.types.NativeType
import org.apache.spark.sql.columnar._
import org.apache.spark.sql.columnar.ColumnarTestUtils._
@@ -67,7 +68,7 @@ class DictionaryEncodingSuite extends FunSuite {
val buffer = builder.build()
val headerSize = CompressionScheme.columnHeaderSize(buffer)
// 4 extra bytes for dictionary size
- val dictionarySize = 4 + values.map(columnType.actualSize).sum
+ val dictionarySize = 4 + rows.map(columnType.actualSize(_, 0)).sum
// 2 bytes for each `Short`
val compressedSize = 4 + dictionarySize + 2 * inputSeq.length
// 4 extra bytes for compression scheme type ID
@@ -97,11 +98,15 @@ class DictionaryEncodingSuite extends FunSuite {
buffer.rewind().position(headerSize + 4)
val decoder = DictionaryEncoding.decoder(buffer, columnType)
+ val mutableRow = new GenericMutableRow(1)
if (inputSeq.nonEmpty) {
inputSeq.foreach { i =>
assert(decoder.hasNext)
- assertResult(values(i), "Wrong decoded value")(decoder.next())
+ assertResult(values(i), "Wrong decoded value") {
+ decoder.next(mutableRow, 0)
+ columnType.getField(mutableRow, 0)
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
index 322f447c24..73f31c0233 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
@@ -31,7 +31,7 @@ class IntegralDeltaSuite extends FunSuite {
def testIntegralDelta[I <: IntegralType](
columnStats: ColumnStats,
columnType: NativeColumnType[I],
- scheme: IntegralDelta[I]) {
+ scheme: CompressionScheme) {
def skeleton(input: Seq[I#JvmType]) {
// -------------
@@ -96,10 +96,15 @@ class IntegralDeltaSuite extends FunSuite {
buffer.rewind().position(headerSize + 4)
val decoder = scheme.decoder(buffer, columnType)
+ val mutableRow = new GenericMutableRow(1)
+
if (input.nonEmpty) {
input.foreach{
assert(decoder.hasNext)
- assertResult(_, "Wrong decoded value")(decoder.next())
+ assertResult(_, "Wrong decoded value") {
+ decoder.next(mutableRow, 0)
+ columnType.getField(mutableRow, 0)
+ }
}
}
assert(!decoder.hasNext)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
index 218c09ac26..4ce2552112 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.columnar.compression
import org.scalatest.FunSuite
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.types.NativeType
import org.apache.spark.sql.columnar._
import org.apache.spark.sql.columnar.ColumnarTestUtils._
@@ -57,7 +58,7 @@ class RunLengthEncodingSuite extends FunSuite {
// Compression scheme ID + compressed contents
val compressedSize = 4 + inputRuns.map { case (index, _) =>
// 4 extra bytes each run for run length
- columnType.actualSize(values(index)) + 4
+ columnType.actualSize(rows(index), 0) + 4
}.sum
// 4 extra bytes for compression scheme type ID
@@ -80,11 +81,15 @@ class RunLengthEncodingSuite extends FunSuite {
buffer.rewind().position(headerSize + 4)
val decoder = RunLengthEncoding.decoder(buffer, columnType)
+ val mutableRow = new GenericMutableRow(1)
if (inputSeq.nonEmpty) {
inputSeq.foreach { i =>
assert(decoder.hasNext)
- assertResult(values(i), "Wrong decoded value")(decoder.next())
+ assertResult(values(i), "Wrong decoded value") {
+ decoder.next(mutableRow, 0)
+ columnType.getField(mutableRow, 0)
+ }
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
index 329f80cad4..84fafcde63 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -25,16 +25,14 @@ import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table =>
import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc}
import org.apache.hadoop.hive.serde2.Deserializer
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector
-
+import org.apache.hadoop.hive.serde2.objectinspector.primitive._
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf}
import org.apache.spark.SerializableWritable
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD}
-
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Row, GenericMutableRow, Literal, Cast}
-import org.apache.spark.sql.catalyst.types.DataType
+import org.apache.spark.sql.catalyst.expressions._
/**
* A trait for subclasses that handle table scans.
@@ -108,12 +106,12 @@ class HadoopTableReader(
val hadoopRDD = createHadoopRdd(tableDesc, inputPathStr, ifc)
val attrsWithIndex = attributes.zipWithIndex
- val mutableRow = new GenericMutableRow(attrsWithIndex.length)
+ val mutableRow = new SpecificMutableRow(attributes.map(_.dataType))
+
val deserializedHadoopRDD = hadoopRDD.mapPartitions { iter =>
val hconf = broadcastedHiveConf.value.value
val deserializer = deserializerClass.newInstance()
deserializer.initialize(hconf, tableDesc.getProperties)
-
HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow)
}
@@ -164,33 +162,32 @@ class HadoopTableReader(
val tableDesc = relation.tableDesc
val broadcastedHiveConf = _broadcastedHiveConf
val localDeserializer = partDeserializer
- val mutableRow = new GenericMutableRow(attributes.length)
-
- // split the attributes (output schema) into 2 categories:
- // (partition keys, ordinal), (normal attributes, ordinal), the ordinal mean the
- // index of the attribute in the output Row.
- val (partitionKeys, attrs) = attributes.zipWithIndex.partition(attr => {
- relation.partitionKeys.indexOf(attr._1) >= 0
- })
-
- def fillPartitionKeys(parts: Array[String], row: GenericMutableRow) = {
- partitionKeys.foreach { case (attr, ordinal) =>
- // get partition key ordinal for a given attribute
- val partOridinal = relation.partitionKeys.indexOf(attr)
- row(ordinal) = Cast(Literal(parts(partOridinal)), attr.dataType).eval(null)
+ val mutableRow = new SpecificMutableRow(attributes.map(_.dataType))
+
+ // Splits all attributes into two groups, partition key attributes and those that are not.
+ // Attached indices indicate the position of each attribute in the output schema.
+ val (partitionKeyAttrs, nonPartitionKeyAttrs) =
+ attributes.zipWithIndex.partition { case (attr, _) =>
+ relation.partitionKeys.contains(attr)
+ }
+
+ def fillPartitionKeys(rawPartValues: Array[String], row: MutableRow) = {
+ partitionKeyAttrs.foreach { case (attr, ordinal) =>
+ val partOrdinal = relation.partitionKeys.indexOf(attr)
+ row(ordinal) = Cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null)
}
}
- // fill the partition key for the given MutableRow Object
+
+ // Fill all partition key values into the given MutableRow object
fillPartitionKeys(partValues, mutableRow)
- val hivePartitionRDD = createHadoopRdd(tableDesc, inputPathStr, ifc)
- hivePartitionRDD.mapPartitions { iter =>
+ createHadoopRdd(tableDesc, inputPathStr, ifc).mapPartitions { iter =>
val hconf = broadcastedHiveConf.value.value
val deserializer = localDeserializer.newInstance()
deserializer.initialize(hconf, partProps)
- // fill the non partition key attributes
- HadoopTableReader.fillObject(iter, deserializer, attrs, mutableRow)
+ // Fill the non-partition-key attributes
+ HadoopTableReader.fillObject(iter, deserializer, nonPartitionKeyAttrs, mutableRow)
}
}.toSeq
@@ -257,38 +254,64 @@ private[hive] object HadoopTableReader extends HiveInspectors {
}
/**
- * Transform the raw data(Writable object) into the Row object for an iterable input
- * @param iter Iterable input which represented as Writable object
- * @param deserializer Deserializer associated with the input writable object
- * @param attrs Represents the row attribute names and its zero-based position in the MutableRow
- * @param row reusable MutableRow object
- *
- * @return Iterable Row object that transformed from the given iterable input.
+ * Transform all given raw `Writable`s into `Row`s.
+ *
+ * @param iterator Iterator of all `Writable`s to be transformed
+ * @param deserializer The `Deserializer` associated with the input `Writable`
+ * @param nonPartitionKeyAttrs Attributes that should be filled together with their corresponding
+ * positions in the output schema
+ * @param mutableRow A reusable `MutableRow` that should be filled
+ * @return An `Iterator[Row]` transformed from `iterator`
*/
def fillObject(
- iter: Iterator[Writable],
+ iterator: Iterator[Writable],
deserializer: Deserializer,
- attrs: Seq[(Attribute, Int)],
- row: GenericMutableRow): Iterator[Row] = {
+ nonPartitionKeyAttrs: Seq[(Attribute, Int)],
+ mutableRow: MutableRow): Iterator[Row] = {
+
val soi = deserializer.getObjectInspector().asInstanceOf[StructObjectInspector]
- // get the field references according to the attributes(output of the reader) required
- val fieldRefs = attrs.map { case (attr, idx) => (soi.getStructFieldRef(attr.name), idx) }
+ val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { case (attr, ordinal) =>
+ soi.getStructFieldRef(attr.name) -> ordinal
+ }.unzip
+
+ // Builds specific unwrappers ahead of time according to object inspector types to avoid pattern
+ // matching and branching costs per row.
+ val unwrappers: Seq[(Any, MutableRow, Int) => Unit] = fieldRefs.map {
+ _.getFieldObjectInspector match {
+ case oi: BooleanObjectInspector =>
+ (value: Any, row: MutableRow, ordinal: Int) => row.setBoolean(ordinal, oi.get(value))
+ case oi: ByteObjectInspector =>
+ (value: Any, row: MutableRow, ordinal: Int) => row.setByte(ordinal, oi.get(value))
+ case oi: ShortObjectInspector =>
+ (value: Any, row: MutableRow, ordinal: Int) => row.setShort(ordinal, oi.get(value))
+ case oi: IntObjectInspector =>
+ (value: Any, row: MutableRow, ordinal: Int) => row.setInt(ordinal, oi.get(value))
+ case oi: LongObjectInspector =>
+ (value: Any, row: MutableRow, ordinal: Int) => row.setLong(ordinal, oi.get(value))
+ case oi: FloatObjectInspector =>
+ (value: Any, row: MutableRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value))
+ case oi: DoubleObjectInspector =>
+ (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value))
+ case oi =>
+ (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrapData(value, oi)
+ }
+ }
// Map each tuple to a row object
- iter.map { value =>
+ iterator.map { value =>
val raw = deserializer.deserialize(value)
- var idx = 0;
- while (idx < fieldRefs.length) {
- val fieldRef = fieldRefs(idx)._1
- val fieldIdx = fieldRefs(idx)._2
- val fieldValue = soi.getStructFieldData(raw, fieldRef)
-
- row(fieldIdx) = unwrapData(fieldValue, fieldRef.getFieldObjectInspector())
-
- idx += 1
+ var i = 0
+ while (i < fieldRefs.length) {
+ val fieldValue = soi.getStructFieldData(raw, fieldRefs(i))
+ if (fieldValue == null) {
+ mutableRow.setNullAt(fieldOrdinals(i))
+ } else {
+ unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i))
+ }
+ i += 1
}
- row: Row
+ mutableRow: Row
}
}
}
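Two ideas in the `fillObject` rewrite above are worth pulling out: the per-field unwrapping strategy is resolved once, before the row loop, and each row is filled through primitive setters on a single reusable mutable row, with `setNullAt` guarding null field values. Here is a compact sketch of that precomputed-unwrapper pattern using made-up stand-ins (`FieldType`, `SimpleMutableRow`, `buildUnwrappers`, `fillRows`) rather than Hive `ObjectInspector`s or Spark's `SpecificMutableRow`:

```scala
object PrecomputedUnwrappersSketch {
  // Simplified stand-ins for the field metadata and mutable row used by fillObject.
  sealed trait FieldType
  case object IntField extends FieldType
  case object LongField extends FieldType
  case object OtherField extends FieldType

  final class SimpleMutableRow(size: Int) {
    private val values = new Array[Any](size)
    def setInt(ordinal: Int, v: Int): Unit = values(ordinal) = v
    def setLong(ordinal: Int, v: Long): Unit = values(ordinal) = v
    def update(ordinal: Int, v: Any): Unit = values(ordinal) = v
    def setNullAt(ordinal: Int): Unit = values(ordinal) = null
  }

  // Built once (per partition): one setter function per field, so the hot loop
  // never pattern matches on field types.
  def buildUnwrappers(
      fieldTypes: IndexedSeq[FieldType]): IndexedSeq[(Any, SimpleMutableRow, Int) => Unit] =
    fieldTypes.map {
      case IntField =>
        (value: Any, row: SimpleMutableRow, ordinal: Int) =>
          row.setInt(ordinal, value.asInstanceOf[Int])
      case LongField =>
        (value: Any, row: SimpleMutableRow, ordinal: Int) =>
          row.setLong(ordinal, value.asInstanceOf[Long])
      case OtherField =>
        (value: Any, row: SimpleMutableRow, ordinal: Int) => row.update(ordinal, value)
    }

  // Per-row loop mirroring fillObject: indexed while loop, null check, reused row.
  def fillRows(
      rawRows: Iterator[Array[Any]],
      unwrappers: IndexedSeq[(Any, SimpleMutableRow, Int) => Unit],
      mutableRow: SimpleMutableRow): Iterator[SimpleMutableRow] =
    rawRows.map { raw =>
      var i = 0
      while (i < unwrappers.length) {
        if (raw(i) == null) mutableRow.setNullAt(i) // avoid NPE on null field values
        else unwrappers(i)(raw(i), mutableRow, i)
        i += 1
      }
      mutableRow // the same row object is returned for every element, as in fillObject
    }
}
```

In the patch itself the match runs over `ObjectInspector` subtypes (`IntObjectInspector`, `LongObjectInspector`, and so on) and the fallback delegates to `unwrapData`; the point of the shape is that type dispatch happens once when the unwrapper array is built, while the per-row `while` loop only does an indexed call into a prebuilt function and a primitive set (or `setNullAt`).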
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 6bf8d18a5c..8c8a8b124a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -295,8 +295,16 @@ class HiveQuerySuite extends HiveComparisonTest {
"SELECT (CASE WHEN key > 2 THEN 3 WHEN 2 > key THEN 2 ELSE 0 END) FROM src WHERE key < 15")
test("implement identity function using case statement") {
- val actual = sql("SELECT (CASE key WHEN key THEN key END) FROM src").collect().toSet
- val expected = sql("SELECT key FROM src").collect().toSet
+ val actual = sql("SELECT (CASE key WHEN key THEN key END) FROM src")
+ .map { case Row(i: Int) => i }
+ .collect()
+ .toSet
+
+ val expected = sql("SELECT key FROM src")
+ .map { case Row(i: Int) => i }
+ .collect()
+ .toSet
+
assert(actual === expected)
}
@@ -559,9 +567,9 @@ class HiveQuerySuite extends HiveComparisonTest {
val testVal = "test.val.0"
val nonexistentKey = "nonexistent"
val KV = "([^=]+)=([^=]*)".r
- def collectResults(rdd: SchemaRDD): Set[(String, String)] =
- rdd.collect().map {
- case Row(key: String, value: String) => key -> value
+ def collectResults(rdd: SchemaRDD): Set[(String, String)] =
+ rdd.collect().map {
+ case Row(key: String, value: String) => key -> value
case Row(KV(key, value)) => key -> value
}.toSet
clear()