author    Cheng Lian <lian.cs.zju@gmail.com>      2014-04-02 12:47:22 -0700
committer Patrick Wendell <pwendell@gmail.com>   2014-04-02 12:47:22 -0700
commit    1faa57971192226837bea32eb29eae5bfb425a7e (patch)
tree      bfbe41e2007801ebd6f62f7b6d51d8a07d51ecd1 /sql
parent    78236334e4ca7518b6d7d9b38464dbbda854a777 (diff)
download  spark-1faa57971192226837bea32eb29eae5bfb425a7e.tar.gz
          spark-1faa57971192226837bea32eb29eae5bfb425a7e.tar.bz2
          spark-1faa57971192226837bea32eb29eae5bfb425a7e.zip
[SPARK-1371][WIP] Compression support for Spark SQL in-memory columnar storage
JIRA issue: [SPARK-1373](https://issues.apache.org/jira/browse/SPARK-1373)

(Although tagged as WIP, this PR is structurally complete. The only things left unimplemented are 3 more compression algorithms: `BooleanBitSet`, `IntDelta` and `LongDelta`, which are trivial to add later in this or another separate PR.)

This PR contains compression support for Spark SQL in-memory columnar storage. Main interfaces include:

* `CompressionScheme`

  Each `CompressionScheme` represents a concrete compression algorithm, which basically consists of an `Encoder` for compression and a `Decoder` for decompression. Algorithms implemented include:

  * `RunLengthEncoding`
  * `DictionaryEncoding`

  Algorithms to be implemented include:

  * `BooleanBitSet`
  * `IntDelta`
  * `LongDelta`

* `CompressibleColumnBuilder`

  A stackable `ColumnBuilder` trait used to build byte buffers for compressible columns. The `CompressionScheme` with the lowest compression ratio is chosen for each column, based on statistical information gathered while elements are appended to the `ColumnBuilder`. However, if no `CompressionScheme` achieves a compression ratio better than 80%, the column is left uncompressed to save CPU time.

  Memory layout of the final byte buffer is shown below:

  ```
     .--------------------------- Column type ID (4 bytes)
     |   .----------------------- Null count N (4 bytes)
     |   |   .------------------- Null positions (4 x N bytes, empty if null count is zero)
     |   |   |     .------------- Compression scheme ID (4 bytes)
     |   |   |     |   .--------- Compressed non-null elements
     V   V   V     V   V
    +---+---+-----+---+---------+
    |   |   | ... |   | ... ... |
    +---+---+-----+---+---------+
     \-----------/ \-----------/
        header         body
  ```

* `CompressibleColumnAccessor`

  A stackable `ColumnAccessor` trait used to iterate over (possibly) compressed column data.

* `ColumnStats`

  Used to collect statistical information while loading data into an in-memory columnar table. Optimizations like partition pruning rely on this information.

  Strictly speaking, `ColumnStats` related code is not part of the compression support. It's included in this PR to validate the row-based API design (which is used to avoid boxing/unboxing cost whenever possible).

A major refactoring change since PR #205 is:

* Refactored all getter/setter methods for primitive types in various places into `ColumnType` classes to remove duplicated code.

Author: Cheng Lian <lian.cs.zju@gmail.com>

Closes #285 from liancheng/memColumnarCompression and squashes the following commits:

ed71bbd [Cheng Lian] Addressed all PR comments by @marmbrus
d3a4fa9 [Cheng Lian] Removed Ordering[T] in ColumnStats for better performance
5034453 [Cheng Lian] Bug fix, more tests, and more refactoring
c298b76 [Cheng Lian] Test suites refactored
2780d6a [Cheng Lian] [WIP] in-memory columnar compression support
211331c [Cheng Lian] WIP: in-memory columnar compression support
85cc59b [Cheng Lian] Refactored ColumnAccessors & ColumnBuilders to remove duplicate code
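The 80% rule described above is the heart of the scheme selection performed in `CompressibleColumnBuilder.build()` (see the diff below). Here is a minimal, self-contained sketch of that decision; `EncoderStats`, `pickEncoder` and the sample numbers are illustrative, not part of the patch:

```
// Hedged sketch of the scheme-selection rule: the candidate with the lowest
// (best) compression ratio wins, but only if it beats the 0.8 threshold;
// otherwise PassThrough (no compression) is used to save CPU time.
case class EncoderStats(name: String, compressedSize: Int, uncompressedSize: Int) {
  // Mirrors Encoder.compressionRatio in this PR: smaller is better.
  def compressionRatio: Double =
    if (uncompressedSize > 0) compressedSize.toDouble / uncompressedSize else 1.0
}

object SchemeSelectionDemo extends App {
  def pickEncoder(candidates: Seq[EncoderStats]): EncoderStats = {
    val best = candidates.minBy(_.compressionRatio)
    if (best.compressionRatio < 0.8) best
    else EncoderStats("PassThrough", best.uncompressedSize, best.uncompressedSize)
  }

  val candidates = Seq(
    EncoderStats("RunLengthEncoding", 400, 1000),   // ratio 0.4: worth compressing
    EncoderStats("DictionaryEncoding", 900, 1000))  // ratio 0.9: not worth it
  println(pickEncoder(candidates).name)             // prints RunLengthEncoding
}
```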
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala                                                                          | 103
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala                                                                           | 125
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala                                                                             | 360
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala                                                                              |  87
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala (renamed from sql/core/src/main/scala/org/apache/spark/sql/columnar/inMemoryColumnarOperators.scala) |   7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala                                                                  |   2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala                                                                   |  29
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala                                                  |  36
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala                                                   |  95
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala                                                           |  94
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala                                                          | 288
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala                                                                        |  61
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala                                                                         | 216
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarQuerySuite.scala                                                                      |   4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestData.scala                                                                        |  55
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala                                                                       | 100
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala                                                             |  43
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala                                                              |  61
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala                                                     | 113
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala                                                      | 130
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala                                               |  43
21 files changed, 1644 insertions, 408 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
index e0c98ecdf8..ffd4894b52 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
@@ -21,7 +21,7 @@ import java.nio.{ByteOrder, ByteBuffer}
import org.apache.spark.sql.catalyst.types.{BinaryType, NativeType, DataType}
import org.apache.spark.sql.catalyst.expressions.MutableRow
-import org.apache.spark.sql.execution.SparkSqlSerializer
+import org.apache.spark.sql.columnar.compression.CompressibleColumnAccessor
/**
* An `Iterator` like trait used to extract values from columnar byte buffer. When a value is
@@ -41,121 +41,66 @@ private[sql] trait ColumnAccessor {
protected def underlyingBuffer: ByteBuffer
}
-private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType](buffer: ByteBuffer)
+private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType](
+ protected val buffer: ByteBuffer,
+ protected val columnType: ColumnType[T, JvmType])
extends ColumnAccessor {
protected def initialize() {}
- def columnType: ColumnType[T, JvmType]
-
def hasNext = buffer.hasRemaining
def extractTo(row: MutableRow, ordinal: Int) {
- doExtractTo(row, ordinal)
+ columnType.setField(row, ordinal, extractSingle(buffer))
}
- protected def doExtractTo(row: MutableRow, ordinal: Int)
+ def extractSingle(buffer: ByteBuffer): JvmType = columnType.extract(buffer)
protected def underlyingBuffer = buffer
}
private[sql] abstract class NativeColumnAccessor[T <: NativeType](
- buffer: ByteBuffer,
- val columnType: NativeColumnType[T])
- extends BasicColumnAccessor[T, T#JvmType](buffer)
+ override protected val buffer: ByteBuffer,
+ override protected val columnType: NativeColumnType[T])
+ extends BasicColumnAccessor(buffer, columnType)
with NullableColumnAccessor
+ with CompressibleColumnAccessor[T]
private[sql] class BooleanColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, BOOLEAN) {
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- row.setBoolean(ordinal, columnType.extract(buffer))
- }
-}
+ extends NativeColumnAccessor(buffer, BOOLEAN)
private[sql] class IntColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, INT) {
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- row.setInt(ordinal, columnType.extract(buffer))
- }
-}
+ extends NativeColumnAccessor(buffer, INT)
private[sql] class ShortColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, SHORT) {
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- row.setShort(ordinal, columnType.extract(buffer))
- }
-}
+ extends NativeColumnAccessor(buffer, SHORT)
private[sql] class LongColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, LONG) {
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- row.setLong(ordinal, columnType.extract(buffer))
- }
-}
+ extends NativeColumnAccessor(buffer, LONG)
private[sql] class ByteColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, BYTE) {
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- row.setByte(ordinal, columnType.extract(buffer))
- }
-}
+ extends NativeColumnAccessor(buffer, BYTE)
private[sql] class DoubleColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, DOUBLE) {
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- row.setDouble(ordinal, columnType.extract(buffer))
- }
-}
+ extends NativeColumnAccessor(buffer, DOUBLE)
private[sql] class FloatColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, FLOAT) {
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- row.setFloat(ordinal, columnType.extract(buffer))
- }
-}
+ extends NativeColumnAccessor(buffer, FLOAT)
private[sql] class StringColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, STRING) {
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- row.setString(ordinal, columnType.extract(buffer))
- }
-}
+ extends NativeColumnAccessor(buffer, STRING)
private[sql] class BinaryColumnAccessor(buffer: ByteBuffer)
- extends BasicColumnAccessor[BinaryType.type, Array[Byte]](buffer)
- with NullableColumnAccessor {
-
- def columnType = BINARY
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- row(ordinal) = columnType.extract(buffer)
- }
-}
+ extends BasicColumnAccessor[BinaryType.type, Array[Byte]](buffer, BINARY)
+ with NullableColumnAccessor
private[sql] class GenericColumnAccessor(buffer: ByteBuffer)
- extends BasicColumnAccessor[DataType, Array[Byte]](buffer)
- with NullableColumnAccessor {
-
- def columnType = GENERIC
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- val serialized = columnType.extract(buffer)
- row(ordinal) = SparkSqlSerializer.deserialize[Any](serialized)
- }
-}
+ extends BasicColumnAccessor[DataType, Array[Byte]](buffer, GENERIC)
+ with NullableColumnAccessor
private[sql] object ColumnAccessor {
- def apply(b: ByteBuffer): ColumnAccessor = {
- // The first 4 bytes in the buffer indicates the column type.
- val buffer = b.duplicate().order(ByteOrder.nativeOrder())
+ def apply(buffer: ByteBuffer): ColumnAccessor = {
+ // The first 4 bytes in the buffer indicate the column type.
val columnTypeId = buffer.getInt()
columnTypeId match {
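For orientation, here is a minimal, self-contained analogue of the read path this refactoring enables. It covers buffer framing only; the real `ColumnAccessor` also handles nulls and compression, and `IntColumnReadDemo` is an illustrative name, not part of the patch:

```
import java.nio.{ByteBuffer, ByteOrder}

// Read an INT column buffer the way ColumnAccessor does: consume the 4-byte
// type ID header first, then drain values while bytes remain (hasNext).
object IntColumnReadDemo extends App {
  val buffer = ByteBuffer.allocate(4 + 3 * 4).order(ByteOrder.nativeOrder())
  buffer.putInt(0)                          // column type ID (0 = INT in this PR)
  Seq(10, 20, 30).foreach(v => buffer.putInt(v))
  buffer.rewind()

  val columnTypeId = buffer.getInt()        // what ColumnAccessor.apply reads first
  println(s"type ID: $columnTypeId")
  while (buffer.hasRemaining) {             // hasNext
    println(buffer.getInt())                // extractSingle for INT
  }
}
```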
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
index 3e622adfd3..048ee66bff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
@@ -22,7 +22,7 @@ import java.nio.{ByteBuffer, ByteOrder}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.columnar.ColumnBuilder._
-import org.apache.spark.sql.execution.SparkSqlSerializer
+import org.apache.spark.sql.columnar.compression.{AllCompressionSchemes, CompressibleColumnBuilder}
private[sql] trait ColumnBuilder {
/**
@@ -30,37 +30,44 @@ private[sql] trait ColumnBuilder {
*/
def initialize(initialSize: Int, columnName: String = "")
+ /**
+ * Appends `row(ordinal)` to the column builder.
+ */
def appendFrom(row: Row, ordinal: Int)
+ /**
+ * Column statistics information
+ */
+ def columnStats: ColumnStats[_, _]
+
+ /**
+ * Returns the final columnar byte buffer.
+ */
def build(): ByteBuffer
}
-private[sql] abstract class BasicColumnBuilder[T <: DataType, JvmType] extends ColumnBuilder {
+private[sql] class BasicColumnBuilder[T <: DataType, JvmType](
+ val columnStats: ColumnStats[T, JvmType],
+ val columnType: ColumnType[T, JvmType])
+ extends ColumnBuilder {
- private var columnName: String = _
- protected var buffer: ByteBuffer = _
+ protected var columnName: String = _
- def columnType: ColumnType[T, JvmType]
+ protected var buffer: ByteBuffer = _
override def initialize(initialSize: Int, columnName: String = "") = {
val size = if (initialSize == 0) DEFAULT_INITIAL_BUFFER_SIZE else initialSize
this.columnName = columnName
- buffer = ByteBuffer.allocate(4 + 4 + size * columnType.defaultSize)
+
+ // Reserves 4 bytes for column type ID
+ buffer = ByteBuffer.allocate(4 + size * columnType.defaultSize)
buffer.order(ByteOrder.nativeOrder()).putInt(columnType.typeId)
}
- // Have to give a concrete implementation to make mixin possible
override def appendFrom(row: Row, ordinal: Int) {
- doAppendFrom(row, ordinal)
- }
-
- // Concrete `ColumnBuilder`s can override this method to append values
- protected def doAppendFrom(row: Row, ordinal: Int)
-
- // Helper method to append primitive values (to avoid boxing cost)
- protected def appendValue(v: JvmType) {
- buffer = ensureFreeSpace(buffer, columnType.actualSize(v))
- columnType.append(v, buffer)
+ val field = columnType.getField(row, ordinal)
+ buffer = ensureFreeSpace(buffer, columnType.actualSize(field))
+ columnType.append(field, buffer)
}
override def build() = {
@@ -69,83 +76,39 @@ private[sql] abstract class BasicColumnBuilder[T <: DataType, JvmType] extends C
}
}
-private[sql] abstract class NativeColumnBuilder[T <: NativeType](
- val columnType: NativeColumnType[T])
- extends BasicColumnBuilder[T, T#JvmType]
+private[sql] abstract class ComplexColumnBuilder[T <: DataType, JvmType](
+ columnType: ColumnType[T, JvmType])
+ extends BasicColumnBuilder[T, JvmType](new NoopColumnStats[T, JvmType], columnType)
with NullableColumnBuilder
-private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(BOOLEAN) {
- override def doAppendFrom(row: Row, ordinal: Int) {
- appendValue(row.getBoolean(ordinal))
- }
-}
-
-private[sql] class IntColumnBuilder extends NativeColumnBuilder(INT) {
- override def doAppendFrom(row: Row, ordinal: Int) {
- appendValue(row.getInt(ordinal))
- }
-}
+private[sql] abstract class NativeColumnBuilder[T <: NativeType](
+ override val columnStats: NativeColumnStats[T],
+ override val columnType: NativeColumnType[T])
+ extends BasicColumnBuilder[T, T#JvmType](columnStats, columnType)
+ with NullableColumnBuilder
+ with AllCompressionSchemes
+ with CompressibleColumnBuilder[T]
-private[sql] class ShortColumnBuilder extends NativeColumnBuilder(SHORT) {
- override def doAppendFrom(row: Row, ordinal: Int) {
- appendValue(row.getShort(ordinal))
- }
-}
+private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN)
-private[sql] class LongColumnBuilder extends NativeColumnBuilder(LONG) {
- override def doAppendFrom(row: Row, ordinal: Int) {
- appendValue(row.getLong(ordinal))
- }
-}
+private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT)
-private[sql] class ByteColumnBuilder extends NativeColumnBuilder(BYTE) {
- override def doAppendFrom(row: Row, ordinal: Int) {
- appendValue(row.getByte(ordinal))
- }
-}
+private[sql] class ShortColumnBuilder extends NativeColumnBuilder(new ShortColumnStats, SHORT)
-private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(DOUBLE) {
- override def doAppendFrom(row: Row, ordinal: Int) {
- appendValue(row.getDouble(ordinal))
- }
-}
+private[sql] class LongColumnBuilder extends NativeColumnBuilder(new LongColumnStats, LONG)
-private[sql] class FloatColumnBuilder extends NativeColumnBuilder(FLOAT) {
- override def doAppendFrom(row: Row, ordinal: Int) {
- appendValue(row.getFloat(ordinal))
- }
-}
+private[sql] class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE)
-private[sql] class StringColumnBuilder extends NativeColumnBuilder(STRING) {
- override def doAppendFrom(row: Row, ordinal: Int) {
- appendValue(row.getString(ordinal))
- }
-}
+private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE)
-private[sql] class BinaryColumnBuilder
- extends BasicColumnBuilder[BinaryType.type, Array[Byte]]
- with NullableColumnBuilder {
+private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT)
- def columnType = BINARY
+private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING)
- override def doAppendFrom(row: Row, ordinal: Int) {
- appendValue(row(ordinal).asInstanceOf[Array[Byte]])
- }
-}
+private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(BINARY)
// TODO (lian) Add support for array, struct and map
-private[sql] class GenericColumnBuilder
- extends BasicColumnBuilder[DataType, Array[Byte]]
- with NullableColumnBuilder {
-
- def columnType = GENERIC
-
- override def doAppendFrom(row: Row, ordinal: Int) {
- val serialized = SparkSqlSerializer.serialize(row(ordinal))
- buffer = ColumnBuilder.ensureFreeSpace(buffer, columnType.actualSize(serialized))
- columnType.append(serialized, buffer)
- }
-}
+private[sql] class GenericColumnBuilder extends ComplexColumnBuilder(GENERIC)
private[sql] object ColumnBuilder {
val DEFAULT_INITIAL_BUFFER_SIZE = 10 * 1024 * 104
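A minimal, self-contained analogue of the builder side shown above (all names here are hypothetical): reserve 4 bytes for the type ID up front, grow the buffer as values are appended, and flip it on `build()`, mirroring `BasicColumnBuilder` and `ensureFreeSpace`:

```
import java.nio.{ByteBuffer, ByteOrder}

// Illustrative stand-in for BasicColumnBuilder specialized to INT.
class MiniIntColumnBuilder(initialRows: Int) {
  private var buffer = ByteBuffer
    .allocate(4 + initialRows * 4)          // 4 bytes reserved for the type ID
    .order(ByteOrder.nativeOrder())
    .putInt(0)                              // type ID 0 = INT in this PR

  def append(v: Int): Unit = {
    if (buffer.remaining() < 4) {           // simplified ensureFreeSpace
      val grown = ByteBuffer.allocate(buffer.capacity() * 2).order(ByteOrder.nativeOrder())
      buffer.flip()
      grown.put(buffer)
      buffer = grown
    }
    buffer.putInt(v)
  }

  def build(): ByteBuffer = { buffer.flip(); buffer }
}

object MiniIntColumnBuilderDemo extends App {
  val builder = new MiniIntColumnBuilder(initialRows = 2)
  Seq(1, 2, 3).foreach(builder.append)
  val column = builder.build()
  println(column.getInt())                              // 0: the type ID header
  while (column.hasRemaining) println(column.getInt())  // 1, 2, 3
}
```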
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
new file mode 100644
index 0000000000..30c6bdc791
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -0,0 +1,360 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.types._
+
+private[sql] sealed abstract class ColumnStats[T <: DataType, JvmType] extends Serializable {
+ /**
+ * Closed lower bound of this column.
+ */
+ def lowerBound: JvmType
+
+ /**
+ * Closed upper bound of this column.
+ */
+ def upperBound: JvmType
+
+ /**
+ * Gathers statistics information from `row(ordinal)`.
+ */
+ def gatherStats(row: Row, ordinal: Int)
+
+ /**
+ * Returns `true` if `lower <= row(ordinal) <= upper`.
+ */
+ def contains(row: Row, ordinal: Int): Boolean
+
+ /**
+ * Returns `true` if `row(ordinal) < upper` holds.
+ */
+ def isAbove(row: Row, ordinal: Int): Boolean
+
+ /**
+ * Returns `true` if `lower < row(ordinal)` holds.
+ */
+ def isBelow(row: Row, ordinal: Int): Boolean
+
+ /**
+ * Returns `true` if `row(ordinal) <= upper` holds.
+ */
+ def isAtOrAbove(row: Row, ordinal: Int): Boolean
+
+ /**
+ * Returns `true` if `lower <= row(ordinal)` holds.
+ */
+ def isAtOrBelow(row: Row, ordinal: Int): Boolean
+}
+
+private[sql] sealed abstract class NativeColumnStats[T <: NativeType]
+ extends ColumnStats[T, T#JvmType] {
+
+ type JvmType = T#JvmType
+
+ protected var (_lower, _upper) = initialBounds
+
+ def initialBounds: (JvmType, JvmType)
+
+ protected def columnType: NativeColumnType[T]
+
+ override def lowerBound: T#JvmType = _lower
+
+ override def upperBound: T#JvmType = _upper
+
+ override def isAtOrAbove(row: Row, ordinal: Int) = {
+ contains(row, ordinal) || isAbove(row, ordinal)
+ }
+
+ override def isAtOrBelow(row: Row, ordinal: Int) = {
+ contains(row, ordinal) || isBelow(row, ordinal)
+ }
+}
+
+private[sql] class NoopColumnStats[T <: DataType, JvmType] extends ColumnStats[T, JvmType] {
+ override def isAtOrBelow(row: Row, ordinal: Int) = true
+
+ override def isAtOrAbove(row: Row, ordinal: Int) = true
+
+ override def isBelow(row: Row, ordinal: Int) = true
+
+ override def isAbove(row: Row, ordinal: Int) = true
+
+ override def contains(row: Row, ordinal: Int) = true
+
+ override def gatherStats(row: Row, ordinal: Int) {}
+
+ override def upperBound = null.asInstanceOf[JvmType]
+
+ override def lowerBound = null.asInstanceOf[JvmType]
+}
+
+private[sql] abstract class BasicColumnStats[T <: NativeType](
+ protected val columnType: NativeColumnType[T])
+ extends NativeColumnStats[T]
+
+private[sql] class BooleanColumnStats extends BasicColumnStats(BOOLEAN) {
+ override def initialBounds = (true, false)
+
+ override def isBelow(row: Row, ordinal: Int) = {
+ lowerBound < columnType.getField(row, ordinal)
+ }
+
+ override def isAbove(row: Row, ordinal: Int) = {
+ columnType.getField(row, ordinal) < upperBound
+ }
+
+ override def contains(row: Row, ordinal: Int) = {
+ val field = columnType.getField(row, ordinal)
+ lowerBound <= field && field <= upperBound
+ }
+
+ override def gatherStats(row: Row, ordinal: Int) {
+ val field = columnType.getField(row, ordinal)
+ if (field > upperBound) _upper = field
+ if (field < lowerBound) _lower = field
+ }
+}
+
+private[sql] class ByteColumnStats extends BasicColumnStats(BYTE) {
+ override def initialBounds = (Byte.MaxValue, Byte.MinValue)
+
+ override def isBelow(row: Row, ordinal: Int) = {
+ lowerBound < columnType.getField(row, ordinal)
+ }
+
+ override def isAbove(row: Row, ordinal: Int) = {
+ columnType.getField(row, ordinal) < upperBound
+ }
+
+ override def contains(row: Row, ordinal: Int) = {
+ val field = columnType.getField(row, ordinal)
+ lowerBound <= field && field <= upperBound
+ }
+
+ override def gatherStats(row: Row, ordinal: Int) {
+ val field = columnType.getField(row, ordinal)
+ if (field > upperBound) _upper = field
+ if (field < lowerBound) _lower = field
+ }
+}
+
+private[sql] class ShortColumnStats extends BasicColumnStats(SHORT) {
+ override def initialBounds = (Short.MaxValue, Short.MinValue)
+
+ override def isBelow(row: Row, ordinal: Int) = {
+ lowerBound < columnType.getField(row, ordinal)
+ }
+
+ override def isAbove(row: Row, ordinal: Int) = {
+ columnType.getField(row, ordinal) < upperBound
+ }
+
+ override def contains(row: Row, ordinal: Int) = {
+ val field = columnType.getField(row, ordinal)
+ lowerBound <= field && field <= upperBound
+ }
+
+ override def gatherStats(row: Row, ordinal: Int) {
+ val field = columnType.getField(row, ordinal)
+ if (field > upperBound) _upper = field
+ if (field < lowerBound) _lower = field
+ }
+}
+
+private[sql] class LongColumnStats extends BasicColumnStats(LONG) {
+ override def initialBounds = (Long.MaxValue, Long.MinValue)
+
+ override def isBelow(row: Row, ordinal: Int) = {
+ lowerBound < columnType.getField(row, ordinal)
+ }
+
+ override def isAbove(row: Row, ordinal: Int) = {
+ columnType.getField(row, ordinal) < upperBound
+ }
+
+ override def contains(row: Row, ordinal: Int) = {
+ val field = columnType.getField(row, ordinal)
+ lowerBound <= field && field <= upperBound
+ }
+
+ override def gatherStats(row: Row, ordinal: Int) {
+ val field = columnType.getField(row, ordinal)
+ if (field > upperBound) _upper = field
+ if (field < lowerBound) _lower = field
+ }
+}
+
+private[sql] class DoubleColumnStats extends BasicColumnStats(DOUBLE) {
+ override def initialBounds = (Double.MaxValue, Double.MinValue)
+
+ override def isBelow(row: Row, ordinal: Int) = {
+ lowerBound < columnType.getField(row, ordinal)
+ }
+
+ override def isAbove(row: Row, ordinal: Int) = {
+ columnType.getField(row, ordinal) < upperBound
+ }
+
+ override def contains(row: Row, ordinal: Int) = {
+ val field = columnType.getField(row, ordinal)
+ lowerBound <= field && field <= upperBound
+ }
+
+ override def gatherStats(row: Row, ordinal: Int) {
+ val field = columnType.getField(row, ordinal)
+ if (field > upperBound) _upper = field
+ if (field < lowerBound) _lower = field
+ }
+}
+
+private[sql] class FloatColumnStats extends BasicColumnStats(FLOAT) {
+ override def initialBounds = (Float.MaxValue, Float.MinValue)
+
+ override def isBelow(row: Row, ordinal: Int) = {
+ lowerBound < columnType.getField(row, ordinal)
+ }
+
+ override def isAbove(row: Row, ordinal: Int) = {
+ columnType.getField(row, ordinal) < upperBound
+ }
+
+ override def contains(row: Row, ordinal: Int) = {
+ val field = columnType.getField(row, ordinal)
+ lowerBound <= field && field <= upperBound
+ }
+
+ override def gatherStats(row: Row, ordinal: Int) {
+ val field = columnType.getField(row, ordinal)
+ if (field > upperBound) _upper = field
+ if (field < lowerBound) _lower = field
+ }
+}
+
+private[sql] object IntColumnStats {
+ val UNINITIALIZED = 0
+ val INITIALIZED = 1
+ val ASCENDING = 2
+ val DESCENDING = 3
+ val UNORDERED = 4
+}
+
+/**
+ * Statistical information for `Int` columns. More information is collected since `Int` is
+ * frequently used. Extra information includes:
+ *
+ * - Ordering state (ascending/descending/unordered), may be used to decide whether binary search
+ * is applicable when searching elements.
+ * - Maximum delta between adjacent elements, may be used to guide the `IntDelta` compression
+ * scheme.
+ *
+ * (These two kinds of information are not used anywhere yet and might be removed later.)
+ */
+private[sql] class IntColumnStats extends BasicColumnStats(INT) {
+ import IntColumnStats._
+
+ private var orderedState = UNINITIALIZED
+ private var lastValue: Int = _
+ private var _maxDelta: Int = _
+
+ def isAscending = orderedState != DESCENDING && orderedState != UNORDERED
+ def isDescending = orderedState != ASCENDING && orderedState != UNORDERED
+ def isOrdered = isAscending || isDescending
+ def maxDelta = _maxDelta
+
+ override def initialBounds = (Int.MaxValue, Int.MinValue)
+
+ override def isBelow(row: Row, ordinal: Int) = {
+ lowerBound < columnType.getField(row, ordinal)
+ }
+
+ override def isAbove(row: Row, ordinal: Int) = {
+ columnType.getField(row, ordinal) < upperBound
+ }
+
+ override def contains(row: Row, ordinal: Int) = {
+ val field = columnType.getField(row, ordinal)
+ lowerBound <= field && field <= upperBound
+ }
+
+ override def gatherStats(row: Row, ordinal: Int) {
+ val field = columnType.getField(row, ordinal)
+
+ if (field > upperBound) _upper = field
+ if (field < lowerBound) _lower = field
+
+ orderedState = orderedState match {
+ case UNINITIALIZED =>
+ lastValue = field
+ INITIALIZED
+
+ case INITIALIZED =>
+ // If all the integers in the column are the same, ordered state is set to Ascending.
+ // TODO (lian) Confirm whether this is the standard behaviour.
+ val nextState = if (field >= lastValue) ASCENDING else DESCENDING
+ _maxDelta = math.abs(field - lastValue)
+ lastValue = field
+ nextState
+
+ case ASCENDING if field < lastValue =>
+ UNORDERED
+
+ case DESCENDING if field > lastValue =>
+ UNORDERED
+
+ case state @ (ASCENDING | DESCENDING) =>
+ _maxDelta = _maxDelta.max(field - lastValue)
+ lastValue = field
+ state
+
+ case _ =>
+ orderedState
+ }
+ }
+}
+
+private[sql] class StringColumnStats extends BasicColumnStats(STRING) {
+ override def initialBounds = (null, null)
+
+ override def gatherStats(row: Row, ordinal: Int) {
+ val field = columnType.getField(row, ordinal)
+ if ((upperBound eq null) || field.compareTo(upperBound) > 0) _upper = field
+ if ((lowerBound eq null) || field.compareTo(lowerBound) < 0) _lower = field
+ }
+
+ override def contains(row: Row, ordinal: Int) = {
+ !(upperBound eq null) && {
+ val field = columnType.getField(row, ordinal)
+ lowerBound.compareTo(field) <= 0 && field.compareTo(upperBound) <= 0
+ }
+ }
+
+ override def isAbove(row: Row, ordinal: Int) = {
+ !(upperBound eq null) && {
+ val field = columnType.getField(row, ordinal)
+ field.compareTo(upperBound) < 0
+ }
+ }
+
+ override def isBelow(row: Row, ordinal: Int) = {
+ !(lowerBound eq null) && {
+ val field = columnType.getField(row, ordinal)
+ lowerBound.compareTo(field) < 0
+ }
+ }
+}
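Every concrete stats class above repeats the same min/max bookkeeping. A minimal, self-contained analogue (with hypothetical names) shows that bookkeeping and the `contains`-style check a partition-pruning pass would rely on:

```
// Illustrative stand-in for the per-type min/max tracking above.
class MiniIntStats {
  private var lower = Int.MaxValue          // initialBounds for INT
  private var upper = Int.MinValue

  def gatherStats(v: Int): Unit = {         // same shape as gatherStats(row, ordinal)
    if (v > upper) upper = v
    if (v < lower) lower = v
  }

  def contains(v: Int): Boolean = lower <= v && v <= upper
}

object MiniIntStatsDemo extends App {
  val stats = new MiniIntStats
  Seq(3, 7, 5).foreach(stats.gatherStats)
  println(stats.contains(6))                // true: 6 lies within [3, 7]
  println(stats.contains(42))               // false: a scan could skip this batch
}
```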
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index a452b86f0c..5be76890af 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -19,7 +19,12 @@ package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.MutableRow
import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.execution.SparkSqlSerializer
/**
* An abstract class that represents type of a column. Used to append/extract Java objects into/from
@@ -51,9 +56,23 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType](
def actualSize(v: JvmType): Int = defaultSize
/**
+ * Returns `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs
+ * whenever possible.
+ */
+ def getField(row: Row, ordinal: Int): JvmType
+
+ /**
+ * Sets `row(ordinal)` to `field`. Subclasses should override this method to avoid boxing/unboxing
+ * costs whenever possible.
+ */
+ def setField(row: MutableRow, ordinal: Int, value: JvmType)
+
+ /**
* Creates a duplicated copy of the value.
*/
def clone(v: JvmType): JvmType = v
+
+ override def toString = getClass.getSimpleName.stripSuffix("$")
}
private[sql] abstract class NativeColumnType[T <: NativeType](
@@ -65,7 +84,7 @@ private[sql] abstract class NativeColumnType[T <: NativeType](
/**
* Scala TypeTag. Can be used to create primitive arrays and hash tables.
*/
- def scalaTag = dataType.tag
+ def scalaTag: TypeTag[dataType.JvmType] = dataType.tag
}
private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) {
@@ -76,6 +95,12 @@ private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) {
def extract(buffer: ByteBuffer) = {
buffer.getInt()
}
+
+ override def setField(row: MutableRow, ordinal: Int, value: Int) {
+ row.setInt(ordinal, value)
+ }
+
+ override def getField(row: Row, ordinal: Int) = row.getInt(ordinal)
}
private[sql] object LONG extends NativeColumnType(LongType, 1, 8) {
@@ -86,6 +111,12 @@ private[sql] object LONG extends NativeColumnType(LongType, 1, 8) {
override def extract(buffer: ByteBuffer) = {
buffer.getLong()
}
+
+ override def setField(row: MutableRow, ordinal: Int, value: Long) {
+ row.setLong(ordinal, value)
+ }
+
+ override def getField(row: Row, ordinal: Int) = row.getLong(ordinal)
}
private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) {
@@ -96,6 +127,12 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) {
override def extract(buffer: ByteBuffer) = {
buffer.getFloat()
}
+
+ override def setField(row: MutableRow, ordinal: Int, value: Float) {
+ row.setFloat(ordinal, value)
+ }
+
+ override def getField(row: Row, ordinal: Int) = row.getFloat(ordinal)
}
private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) {
@@ -106,6 +143,12 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) {
override def extract(buffer: ByteBuffer) = {
buffer.getDouble()
}
+
+ override def setField(row: MutableRow, ordinal: Int, value: Double) {
+ row.setDouble(ordinal, value)
+ }
+
+ override def getField(row: Row, ordinal: Int) = row.getDouble(ordinal)
}
private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) {
@@ -116,6 +159,12 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) {
override def extract(buffer: ByteBuffer) = {
if (buffer.get() == 1) true else false
}
+
+ override def setField(row: MutableRow, ordinal: Int, value: Boolean) {
+ row.setBoolean(ordinal, value)
+ }
+
+ override def getField(row: Row, ordinal: Int) = row.getBoolean(ordinal)
}
private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) {
@@ -126,6 +175,12 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) {
override def extract(buffer: ByteBuffer) = {
buffer.get()
}
+
+ override def setField(row: MutableRow, ordinal: Int, value: Byte) {
+ row.setByte(ordinal, value)
+ }
+
+ override def getField(row: Row, ordinal: Int) = row.getByte(ordinal)
}
private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) {
@@ -136,6 +191,12 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) {
override def extract(buffer: ByteBuffer) = {
buffer.getShort()
}
+
+ override def setField(row: MutableRow, ordinal: Int, value: Short) {
+ row.setShort(ordinal, value)
+ }
+
+ override def getField(row: Row, ordinal: Int) = row.getShort(ordinal)
}
private[sql] object STRING extends NativeColumnType(StringType, 7, 8) {
@@ -152,6 +213,12 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) {
buffer.get(stringBytes, 0, length)
new String(stringBytes)
}
+
+ override def setField(row: MutableRow, ordinal: Int, value: String) {
+ row.setString(ordinal, value)
+ }
+
+ override def getField(row: Row, ordinal: Int) = row.getString(ordinal)
}
private[sql] sealed abstract class ByteArrayColumnType[T <: DataType](
@@ -173,15 +240,27 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType](
}
}
-private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](8, 16)
+private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](8, 16) {
+ override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]) {
+ row(ordinal) = value
+ }
+
+ override def getField(row: Row, ordinal: Int) = row(ordinal).asInstanceOf[Array[Byte]]
+}
// Used to process generic objects (all types other than those listed above). Objects should be
// serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized
// byte array.
-private[sql] object GENERIC extends ByteArrayColumnType[DataType](9, 16)
+private[sql] object GENERIC extends ByteArrayColumnType[DataType](9, 16) {
+ override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]) {
+ row(ordinal) = SparkSqlSerializer.deserialize[Any](value)
+ }
+
+ override def getField(row: Row, ordinal: Int) = SparkSqlSerializer.serialize(row(ordinal))
+}
private[sql] object ColumnType {
- implicit def dataTypeToColumnType(dataType: DataType): ColumnType[_, _] = {
+ def apply(dataType: DataType): ColumnType[_, _] = {
dataType match {
case IntegerType => INT
case LongType => LONG
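The point of the new `getField`/`setField` methods is to keep primitive values off the `Any`-typed path. Below is a minimal analogue of the pattern; `MiniRow` and friends are hypothetical stand-ins for Catalyst's `Row`/`MutableRow`:

```
// Sketch of the specialized getter/setter pattern: primitive storage plus
// typed accessors, so the hot append/extract path never boxes.
trait MiniRow { def getInt(i: Int): Int }

class MiniMutableRow(size: Int) extends MiniRow {
  private val ints = new Array[Int](size)   // primitive backing array
  def getInt(i: Int): Int = ints(i)
  def setInt(i: Int, v: Int): Unit = ints(i) = v
}

object MiniIntColumnType {
  def getField(row: MiniRow, ordinal: Int): Int = row.getInt(ordinal)  // like INT.getField
  def setField(row: MiniMutableRow, ordinal: Int, v: Int): Unit = row.setInt(ordinal, v)
}

object FieldAccessDemo extends App {
  val row = new MiniMutableRow(1)
  MiniIntColumnType.setField(row, 0, 42)
  println(MiniIntColumnType.getField(row, 0)) // 42, via the primitive path
}
```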
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/inMemoryColumnarOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index f853759e5a..8a24733047 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/inMemoryColumnarOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -21,9 +21,6 @@ import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Attribute}
import org.apache.spark.sql.execution.{SparkPlan, LeafNode}
import org.apache.spark.sql.Row
-/* Implicit conversions */
-import org.apache.spark.sql.columnar.ColumnType._
-
private[sql] case class InMemoryColumnarTableScan(attributes: Seq[Attribute], child: SparkPlan)
extends LeafNode {
@@ -32,8 +29,8 @@ private[sql] case class InMemoryColumnarTableScan(attributes: Seq[Attribute], ch
lazy val cachedColumnBuffers = {
val output = child.output
val cached = child.execute().mapPartitions { iterator =>
- val columnBuilders = output.map { a =>
- ColumnBuilder(a.dataType.typeId, 0, a.name)
+ val columnBuilders = output.map { attribute =>
+ ColumnBuilder(ColumnType(attribute.dataType).typeId, 0, attribute.name)
}.toArray
var row: Row = null
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala
index 2970c609b9..7d49ab07f7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala
@@ -29,7 +29,7 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor {
private var nextNullIndex: Int = _
private var pos: Int = 0
- abstract override def initialize() {
+ abstract override protected def initialize() {
nullsBuffer = underlyingBuffer.duplicate().order(ByteOrder.nativeOrder())
nullCount = nullsBuffer.getInt()
nextNullIndex = if (nullCount > 0) nullsBuffer.getInt() else -1
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
index 048d1f05c7..2a3b6fc1e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
@@ -22,10 +22,18 @@ import java.nio.{ByteBuffer, ByteOrder}
import org.apache.spark.sql.Row
/**
- * Builds a nullable column. The byte buffer of a nullable column contains:
- * - 4 bytes for the null count (number of nulls)
- * - positions for each null, in ascending order
- * - the non-null data (column data type, compression type, data...)
+ * A stackable trait used for building byte buffer for a column containing null values. Memory
+ * layout of the final byte buffer is:
+ * {{{
+ *    .----------------------- Column type ID (4 bytes)
+ *    |   .------------------- Null count N (4 bytes)
+ *    |   |   .--------------- Null positions (4 x N bytes, empty if null count is zero)
+ *    |   |   |     .--------- Non-null elements
+ *    V   V   V     V
+ *   +---+---+-----+---------+
+ *   |   |   | ... | ... ... |
+ *   +---+---+-----+---------+
+ * }}}
*/
private[sql] trait NullableColumnBuilder extends ColumnBuilder {
private var nulls: ByteBuffer = _
@@ -59,19 +67,8 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder {
nulls.limit(nullDataLen)
nulls.rewind()
- // Column type ID is moved to the front, follows the null count, then non-null data
- //
- // +---------+
- // | 4 bytes | Column type ID
- // +---------+
- // | 4 bytes | Null count
- // +---------+
- // | ... | Null positions (if null count is not zero)
- // +---------+
- // | ... | Non-null part (without column type ID)
- // +---------+
val buffer = ByteBuffer
- .allocate(4 + nullDataLen + nonNulls.limit)
+ .allocate(4 + 4 + nullDataLen + nonNulls.remaining())
.order(ByteOrder.nativeOrder())
.putInt(typeId)
.putInt(nullCount)
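As a worked example of the layout documented above (the helper below is illustrative, not part of the patch): the final buffer is the 4-byte type ID, the 4-byte null count, one 4-byte position per null, then the non-null body.

```
// Worked size arithmetic for the nullable buffer layout.
object NullableBufferSizeDemo extends App {
  def bufferSize(nullCount: Int, nonNullBodyBytes: Int): Int =
    4 + 4 + 4 * nullCount + nonNullBodyBytes  // type ID + null count + positions + body

  // An INT column with 100 rows, 3 of them null: body = 97 * 4 bytes.
  println(bufferSize(nullCount = 3, nonNullBodyBytes = 97 * 4)) // 408
}
```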
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala
new file mode 100644
index 0000000000..878cb84de1
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar.compression
+
+import java.nio.ByteBuffer
+
+import org.apache.spark.sql.catalyst.types.NativeType
+import org.apache.spark.sql.columnar.{ColumnAccessor, NativeColumnAccessor}
+
+private[sql] trait CompressibleColumnAccessor[T <: NativeType] extends ColumnAccessor {
+ this: NativeColumnAccessor[T] =>
+
+ private var decoder: Decoder[T] = _
+
+ abstract override protected def initialize() = {
+ super.initialize()
+ decoder = CompressionScheme(underlyingBuffer.getInt()).decoder(buffer, columnType)
+ }
+
+ abstract override def extractSingle(buffer: ByteBuffer): T#JvmType = decoder.next()
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala
new file mode 100644
index 0000000000..3ac4b358dd
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar.compression
+
+import java.nio.{ByteBuffer, ByteOrder}
+
+import org.apache.spark.sql.{Logging, Row}
+import org.apache.spark.sql.catalyst.types.NativeType
+import org.apache.spark.sql.columnar.{ColumnBuilder, NativeColumnBuilder}
+
+/**
+ * A stackable trait that builds optionally compressed byte buffer for a column. Memory layout of
+ * the final byte buffer is:
+ * {{{
+ *    .--------------------------- Column type ID (4 bytes)
+ *    |   .----------------------- Null count N (4 bytes)
+ *    |   |   .------------------- Null positions (4 x N bytes, empty if null count is zero)
+ *    |   |   |     .------------- Compression scheme ID (4 bytes)
+ *    |   |   |     |   .--------- Compressed non-null elements
+ *    V   V   V     V   V
+ *   +---+---+-----+---+---------+
+ *   |   |   | ... |   | ... ... |
+ *   +---+---+-----+---+---------+
+ *    \-----------/ \-----------/
+ *       header         body
+ * }}}
+ */
+private[sql] trait CompressibleColumnBuilder[T <: NativeType]
+ extends ColumnBuilder with Logging {
+
+ this: NativeColumnBuilder[T] with WithCompressionSchemes =>
+
+ import CompressionScheme._
+
+ val compressionEncoders = schemes.filter(_.supports(columnType)).map(_.encoder)
+
+ protected def isWorthCompressing(encoder: Encoder) = {
+ encoder.compressionRatio < 0.8
+ }
+
+ private def gatherCompressibilityStats(row: Row, ordinal: Int) {
+ val field = columnType.getField(row, ordinal)
+
+ var i = 0
+ while (i < compressionEncoders.length) {
+ compressionEncoders(i).gatherCompressibilityStats(field, columnType)
+ i += 1
+ }
+ }
+
+ abstract override def appendFrom(row: Row, ordinal: Int) {
+ super.appendFrom(row, ordinal)
+ gatherCompressibilityStats(row, ordinal)
+ }
+
+ abstract override def build() = {
+ val rawBuffer = super.build()
+ val encoder = {
+ val candidate = compressionEncoders.minBy(_.compressionRatio)
+ if (isWorthCompressing(candidate)) candidate else PassThrough.encoder
+ }
+
+ val headerSize = columnHeaderSize(rawBuffer)
+ val compressedSize = if (encoder.compressedSize == 0) {
+ rawBuffer.limit - headerSize
+ } else {
+ encoder.compressedSize
+ }
+
+ // Reserves 4 bytes for compression scheme ID
+ val compressedBuffer = ByteBuffer
+ .allocate(headerSize + 4 + compressedSize)
+ .order(ByteOrder.nativeOrder)
+
+ copyColumnHeader(rawBuffer, compressedBuffer)
+
+ logger.info(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}")
+ encoder.compress(rawBuffer, compressedBuffer, columnType)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
new file mode 100644
index 0000000000..d3a4ac8df9
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar.compression
+
+import java.nio.ByteBuffer
+
+import org.apache.spark.sql.catalyst.types.NativeType
+import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType}
+
+private[sql] trait Encoder {
+ def gatherCompressibilityStats[T <: NativeType](
+ value: T#JvmType,
+ columnType: ColumnType[T, T#JvmType]) {}
+
+ def compressedSize: Int
+
+ def uncompressedSize: Int
+
+ def compressionRatio: Double = {
+ if (uncompressedSize > 0) compressedSize.toDouble / uncompressedSize else 1.0
+ }
+
+ def compress[T <: NativeType](
+ from: ByteBuffer,
+ to: ByteBuffer,
+ columnType: ColumnType[T, T#JvmType]): ByteBuffer
+}
+
+private[sql] trait Decoder[T <: NativeType] extends Iterator[T#JvmType]
+
+private[sql] trait CompressionScheme {
+ def typeId: Int
+
+ def supports(columnType: ColumnType[_, _]): Boolean
+
+ def encoder: Encoder
+
+ def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T]
+}
+
+private[sql] trait WithCompressionSchemes {
+ def schemes: Seq[CompressionScheme]
+}
+
+private[sql] trait AllCompressionSchemes extends WithCompressionSchemes {
+ override val schemes: Seq[CompressionScheme] = {
+ Seq(PassThrough, RunLengthEncoding, DictionaryEncoding)
+ }
+}
+
+private[sql] object CompressionScheme {
+ def apply(typeId: Int): CompressionScheme = typeId match {
+ case PassThrough.typeId => PassThrough
+ case _ => throw new UnsupportedOperationException()
+ }
+
+ def copyColumnHeader(from: ByteBuffer, to: ByteBuffer) {
+ // Writes column type ID
+ to.putInt(from.getInt())
+
+ // Writes null count
+ val nullCount = from.getInt()
+ to.putInt(nullCount)
+
+ // Writes null positions
+ var i = 0
+ while (i < nullCount) {
+ to.putInt(from.getInt())
+ i += 1
+ }
+ }
+
+ def columnHeaderSize(columnBuffer: ByteBuffer): Int = {
+ val header = columnBuffer.duplicate()
+ val nullCount = header.getInt(4)
+ // Column type ID + null count + null positions
+ 4 + 4 + 4 * nullCount
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
new file mode 100644
index 0000000000..dc2c153faf
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
@@ -0,0 +1,288 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar.compression
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable
+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.runtimeMirror
+
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.catalyst.types.NativeType
+import org.apache.spark.sql.columnar._
+
+private[sql] case object PassThrough extends CompressionScheme {
+ override val typeId = 0
+
+ override def supports(columnType: ColumnType[_, _]) = true
+
+ override def encoder = new this.Encoder
+
+ override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = {
+ new this.Decoder(buffer, columnType)
+ }
+
+ class Encoder extends compression.Encoder {
+ override def uncompressedSize = 0
+
+ override def compressedSize = 0
+
+ override def compress[T <: NativeType](
+ from: ByteBuffer,
+ to: ByteBuffer,
+ columnType: ColumnType[T, T#JvmType]) = {
+
+ // Writes compression type ID and copies raw contents
+ to.putInt(PassThrough.typeId).put(from).rewind()
+ to
+ }
+ }
+
+ class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T])
+ extends compression.Decoder[T] {
+
+ override def next() = columnType.extract(buffer)
+
+ override def hasNext = buffer.hasRemaining
+ }
+}
+
+private[sql] case object RunLengthEncoding extends CompressionScheme {
+ override def typeId = 1
+
+ override def encoder = new this.Encoder
+
+ override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = {
+ new this.Decoder(buffer, columnType)
+ }
+
+ override def supports(columnType: ColumnType[_, _]) = columnType match {
+ case INT | LONG | SHORT | BYTE | STRING | BOOLEAN => true
+ case _ => false
+ }
+
+ class Encoder extends compression.Encoder {
+ private var _uncompressedSize = 0
+ private var _compressedSize = 0
+
+ // Using `MutableRow` to store the last value to avoid boxing/unboxing cost.
+ private val lastValue = new GenericMutableRow(1)
+ private var lastRun = 0
+
+ override def uncompressedSize = _uncompressedSize
+
+ override def compressedSize = _compressedSize
+
+ override def gatherCompressibilityStats[T <: NativeType](
+ value: T#JvmType,
+ columnType: ColumnType[T, T#JvmType]) {
+
+ val actualSize = columnType.actualSize(value)
+ _uncompressedSize += actualSize
+
+ if (lastValue.isNullAt(0)) {
+ columnType.setField(lastValue, 0, value)
+ lastRun = 1
+ _compressedSize += actualSize + 4
+ } else {
+ if (columnType.getField(lastValue, 0) == value) {
+ lastRun += 1
+ } else {
+ _compressedSize += actualSize + 4
+ columnType.setField(lastValue, 0, value)
+ lastRun = 1
+ }
+ }
+ }
+
+ override def compress[T <: NativeType](
+ from: ByteBuffer,
+ to: ByteBuffer,
+ columnType: ColumnType[T, T#JvmType]) = {
+
+ to.putInt(RunLengthEncoding.typeId)
+
+ if (from.hasRemaining) {
+ var currentValue = columnType.extract(from)
+ var currentRun = 1
+
+ while (from.hasRemaining) {
+ val value = columnType.extract(from)
+
+ if (value == currentValue) {
+ currentRun += 1
+ } else {
+ // Writes current run
+ columnType.append(currentValue, to)
+ to.putInt(currentRun)
+
+ // Resets current run
+ currentValue = value
+ currentRun = 1
+ }
+ }
+
+ // Writes the last run
+ columnType.append(currentValue, to)
+ to.putInt(currentRun)
+ }
+
+ to.rewind()
+ to
+ }
+ }
+
+ class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T])
+ extends compression.Decoder[T] {
+
+ private var run = 0
+ private var valueCount = 0
+ private var currentValue: T#JvmType = _
+
+ override def next() = {
+ if (valueCount == run) {
+ currentValue = columnType.extract(buffer)
+ run = buffer.getInt()
+ valueCount = 1
+ } else {
+ valueCount += 1
+ }
+
+ currentValue
+ }
+
+ override def hasNext = buffer.hasRemaining
+ }
+}
+
+private[sql] case object DictionaryEncoding extends CompressionScheme {
+ override def typeId: Int = 2
+
+ // 32K unique values allowed
+ private val MAX_DICT_SIZE = Short.MaxValue - 1
+
+ override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = {
+ new this.Decoder[T](buffer, columnType)
+ }
+
+ override def encoder = new this.Encoder
+
+ override def supports(columnType: ColumnType[_, _]) = columnType match {
+ case INT | LONG | STRING => true
+ case _ => false
+ }
+
+ class Encoder extends compression.Encoder {
+ // Size of the input, uncompressed, in bytes. Note that we only count until the dictionary
+ // overflows.
+ private var _uncompressedSize = 0
+
+ // If the number of distinct elements is too large, we discard the use of dictionary encoding
+ // and set the overflow flag to true.
+ private var overflow = false
+
+ // Total number of elements.
+ private var count = 0
+
+ // The reverse mapping of `dictionary`, i.e. mapping each encoded short integer back to the value itself.
+ private var values = new mutable.ArrayBuffer[Any](1024)
+
+ // The dictionary that maps a value to the encoded short integer.
+ private val dictionary = mutable.HashMap.empty[Any, Short]
+
+ // Size of the serialized dictionary in bytes. Initialized to 4 since we need at least an `Int`
+ // to store dictionary element count.
+ private var dictionarySize = 4
+
+ override def gatherCompressibilityStats[T <: NativeType](
+ value: T#JvmType,
+ columnType: ColumnType[T, T#JvmType]) {
+
+ if (!overflow) {
+ val actualSize = columnType.actualSize(value)
+ count += 1
+ _uncompressedSize += actualSize
+
+ if (!dictionary.contains(value)) {
+ if (dictionary.size < MAX_DICT_SIZE) {
+ val clone = columnType.clone(value)
+ values += clone
+ dictionarySize += actualSize
+ dictionary(clone) = dictionary.size.toShort
+ } else {
+ overflow = true
+ values.clear()
+ dictionary.clear()
+ }
+ }
+ }
+ }
+
+ override def compress[T <: NativeType](
+ from: ByteBuffer,
+ to: ByteBuffer,
+ columnType: ColumnType[T, T#JvmType]) = {
+
+ if (overflow) {
+ throw new IllegalStateException(
+ "Dictionary encoding should not be used because of dictionary overflow.")
+ }
+
+ to.putInt(DictionaryEncoding.typeId)
+ .putInt(dictionary.size)
+
+ var i = 0
+ while (i < values.length) {
+ columnType.append(values(i).asInstanceOf[T#JvmType], to)
+ i += 1
+ }
+
+ while (from.hasRemaining) {
+ to.putShort(dictionary(columnType.extract(from)))
+ }
+
+ to.rewind()
+ to
+ }
+
+ override def uncompressedSize = _uncompressedSize
+
+ override def compressedSize = if (overflow) Int.MaxValue else dictionarySize + count * 2
+ }
+
+ class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T])
+ extends compression.Decoder[T] {
+
+ private val dictionary = {
+ // TODO Can we clean up this mess? Maybe move this to `DataType`?
+ implicit val classTag = {
+ val mirror = runtimeMirror(getClass.getClassLoader)
+ ClassTag[T#JvmType](mirror.runtimeClass(columnType.scalaTag.tpe))
+ }
+
+ Array.fill(buffer.getInt()) {
+ columnType.extract(buffer)
+ }
+ }
+
+ override def next() = dictionary(buffer.getShort())
+
+ override def hasNext = buffer.hasRemaining
+ }
+}
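To make the run-length format above concrete, here is a self-contained sketch of the grouping and its size accounting; `runs` and the sample column are illustrative only. Each run costs one element plus a 4-byte run count, which is why long runs compress well and runs of length one cost 4 extra bytes each:

```
// Illustrative run-length grouping matching the (value, run count) pairs
// written by RunLengthEncoding.Encoder.compress above.
object RleDemo extends App {
  def runs(values: Seq[Int]): Seq[(Int, Int)] =
    values.foldLeft(List.empty[(Int, Int)]) {
      case ((v, n) :: rest, x) if x == v => (v, n + 1) :: rest  // extend current run
      case (acc, x)                      => (x, 1) :: acc       // start a new run
    }.reverse

  val column = Seq(7, 7, 7, 7, 9, 9)
  println(runs(column))                                   // List((7,4), (9,2))
  val compressed = runs(column).size * (4 + 4)            // INT element + 4-byte run count
  println(s"$compressed bytes vs ${column.size * 4} raw") // 16 bytes vs 24 raw
}
```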
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
new file mode 100644
index 0000000000..78640b876d
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.catalyst.types._
+
+class ColumnStatsSuite extends FunSuite {
+ testColumnStats(classOf[BooleanColumnStats], BOOLEAN)
+ testColumnStats(classOf[ByteColumnStats], BYTE)
+ testColumnStats(classOf[ShortColumnStats], SHORT)
+ testColumnStats(classOf[IntColumnStats], INT)
+ testColumnStats(classOf[LongColumnStats], LONG)
+ testColumnStats(classOf[FloatColumnStats], FLOAT)
+ testColumnStats(classOf[DoubleColumnStats], DOUBLE)
+ testColumnStats(classOf[StringColumnStats], STRING)
+
+ def testColumnStats[T <: NativeType, U <: NativeColumnStats[T]](
+ columnStatsClass: Class[U],
+ columnType: NativeColumnType[T]) {
+
+ val columnStatsName = columnStatsClass.getSimpleName
+
+ test(s"$columnStatsName: empty") {
+ val columnStats = columnStatsClass.newInstance()
+ expectResult(columnStats.initialBounds, "Wrong initial bounds") {
+ (columnStats.lowerBound, columnStats.upperBound)
+ }
+ }
+
+ test(s"$columnStatsName: non-empty") {
+ import ColumnarTestUtils._
+
+ val columnStats = columnStatsClass.newInstance()
+ val rows = Seq.fill(10)(makeRandomRow(columnType))
+ rows.foreach(columnStats.gatherStats(_, 0))
+
+ val values = rows.map(_.head.asInstanceOf[T#JvmType])
+ val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#JvmType]]
+
+ expectResult(values.min(ordering), "Wrong lower bound")(columnStats.lowerBound)
+ expectResult(values.max(ordering), "Wrong upper bound")(columnStats.upperBound)
+ }
+ }
+}
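The bounds asserted in this suite are just a running minimum and maximum gathered per column. A standalone sketch of that pattern, with `SimpleIntStats` as a hypothetical stand-in for the stats classes under test:

```
class SimpleIntStats {
  // Initialized to the extremes so the first gathered value becomes both bounds.
  var lowerBound: Int = Int.MaxValue
  var upperBound: Int = Int.MinValue

  def gatherStats(value: Int): Unit = {
    if (value < lowerBound) lowerBound = value
    if (value > upperBound) upperBound = value
  }
}

object StatsSketch extends App {
  val stats = new SimpleIntStats
  Seq(3, -1, 7, 4).foreach(stats.gatherStats)
  assert(stats.lowerBound == -1 && stats.upperBound == 7)
}
```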
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index 2d431affbc..1d3608ed2d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -19,46 +19,56 @@ package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
-import scala.util.Random
-
import org.scalatest.FunSuite
import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.columnar.ColumnarTestUtils._
import org.apache.spark.sql.execution.SparkSqlSerializer
class ColumnTypeSuite extends FunSuite {
- val columnTypes = Seq(INT, SHORT, LONG, BYTE, DOUBLE, FLOAT, STRING, BINARY, GENERIC)
+ val DEFAULT_BUFFER_SIZE = 512
test("defaultSize") {
- val defaultSize = Seq(4, 2, 8, 1, 8, 4, 8, 16, 16)
+ val checks = Map(
+ INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4,
+ BOOLEAN -> 1, STRING -> 8, BINARY -> 16, GENERIC -> 16)
- columnTypes.zip(defaultSize).foreach { case (columnType, size) =>
- assert(columnType.defaultSize === size)
+ checks.foreach { case (columnType, expectedSize) =>
+ expectResult(expectedSize, s"Wrong defaultSize for $columnType") {
+ columnType.defaultSize
+ }
}
}
test("actualSize") {
- val expectedSizes = Seq(4, 2, 8, 1, 8, 4, 4 + 5, 4 + 4, 4 + 11)
- val actualSizes = Seq(
- INT.actualSize(Int.MaxValue),
- SHORT.actualSize(Short.MaxValue),
- LONG.actualSize(Long.MaxValue),
- BYTE.actualSize(Byte.MaxValue),
- DOUBLE.actualSize(Double.MaxValue),
- FLOAT.actualSize(Float.MaxValue),
- STRING.actualSize("hello"),
- BINARY.actualSize(new Array[Byte](4)),
- GENERIC.actualSize(SparkSqlSerializer.serialize(Map(1 -> "a"))))
-
- expectedSizes.zip(actualSizes).foreach { case (expected, actual) =>
- assert(expected === actual)
+ def checkActualSize[T <: DataType, JvmType](
+ columnType: ColumnType[T, JvmType],
+ value: JvmType,
+ expected: Int) {
+
+ expectResult(expected, s"Wrong actualSize for $columnType") {
+ columnType.actualSize(value)
+ }
}
+
+ checkActualSize(INT, Int.MaxValue, 4)
+ checkActualSize(SHORT, Short.MaxValue, 2)
+ checkActualSize(LONG, Long.MaxValue, 8)
+ checkActualSize(BYTE, Byte.MaxValue, 1)
+ checkActualSize(DOUBLE, Double.MaxValue, 8)
+ checkActualSize(FLOAT, Float.MaxValue, 4)
+ checkActualSize(BOOLEAN, true, 1)
+ checkActualSize(STRING, "hello", 4 + 5)
+
+ val binary = Array.fill[Byte](4)(0: Byte)
+ checkActualSize(BINARY, binary, 4 + 4)
+
+ val generic = Map(1 -> "a")
+ checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 11)
}
- testNumericColumnType[BooleanType.type, Boolean](
+ testNativeColumnType[BooleanType.type](
BOOLEAN,
- Array.fill(4)(Random.nextBoolean()),
- ByteBuffer.allocate(32),
(buffer: ByteBuffer, v: Boolean) => {
buffer.put((if (v) 1 else 0).toByte)
},
@@ -66,105 +76,42 @@ class ColumnTypeSuite extends FunSuite {
buffer.get() == 1
})
- testNumericColumnType[IntegerType.type, Int](
- INT,
- Array.fill(4)(Random.nextInt()),
- ByteBuffer.allocate(32),
- (_: ByteBuffer).putInt(_),
- (_: ByteBuffer).getInt)
-
- testNumericColumnType[ShortType.type, Short](
- SHORT,
- Array.fill(4)(Random.nextInt(Short.MaxValue).asInstanceOf[Short]),
- ByteBuffer.allocate(32),
- (_: ByteBuffer).putShort(_),
- (_: ByteBuffer).getShort)
-
- testNumericColumnType[LongType.type, Long](
- LONG,
- Array.fill(4)(Random.nextLong()),
- ByteBuffer.allocate(64),
- (_: ByteBuffer).putLong(_),
- (_: ByteBuffer).getLong)
-
- testNumericColumnType[ByteType.type, Byte](
- BYTE,
- Array.fill(4)(Random.nextInt(Byte.MaxValue).asInstanceOf[Byte]),
- ByteBuffer.allocate(64),
- (_: ByteBuffer).put(_),
- (_: ByteBuffer).get)
-
- testNumericColumnType[DoubleType.type, Double](
- DOUBLE,
- Array.fill(4)(Random.nextDouble()),
- ByteBuffer.allocate(64),
- (_: ByteBuffer).putDouble(_),
- (_: ByteBuffer).getDouble)
-
- testNumericColumnType[FloatType.type, Float](
- FLOAT,
- Array.fill(4)(Random.nextFloat()),
- ByteBuffer.allocate(64),
- (_: ByteBuffer).putFloat(_),
- (_: ByteBuffer).getFloat)
-
- test("STRING") {
- val buffer = ByteBuffer.allocate(128)
- val seq = Array("hello", "world", "spark", "sql")
-
- seq.map(_.getBytes).foreach { bytes: Array[Byte] =>
- buffer.putInt(bytes.length).put(bytes)
- }
+ testNativeColumnType[IntegerType.type](INT, _.putInt(_), _.getInt)
- buffer.rewind()
- seq.foreach { s =>
- assert(s === STRING.extract(buffer))
- }
+ testNativeColumnType[ShortType.type](SHORT, _.putShort(_), _.getShort)
- buffer.rewind()
- seq.foreach(STRING.append(_, buffer))
+ testNativeColumnType[LongType.type](LONG, _.putLong(_), _.getLong)
- buffer.rewind()
- seq.foreach { s =>
- val length = buffer.getInt
- assert(length === s.getBytes.length)
+ testNativeColumnType[ByteType.type](BYTE, _.put(_), _.get)
+
+ testNativeColumnType[DoubleType.type](DOUBLE, _.putDouble(_), _.getDouble)
+
+ testNativeColumnType[FloatType.type](FLOAT, _.putFloat(_), _.getFloat)
+ testNativeColumnType[StringType.type](
+ STRING,
+ (buffer: ByteBuffer, string: String) => {
+ val bytes = string.getBytes()
+ buffer.putInt(bytes.length).put(bytes)
+ },
+ (buffer: ByteBuffer) => {
+ val length = buffer.getInt()
val bytes = new Array[Byte](length)
buffer.get(bytes, 0, length)
- assert(s === new String(bytes))
- }
- }
-
- test("BINARY") {
- val buffer = ByteBuffer.allocate(128)
- val seq = Array.fill(4) {
- val bytes = new Array[Byte](4)
- Random.nextBytes(bytes)
- bytes
- }
+ new String(bytes)
+ })
- seq.foreach { bytes =>
+ testColumnType[BinaryType.type, Array[Byte]](
+ BINARY,
+ (buffer: ByteBuffer, bytes: Array[Byte]) => {
buffer.putInt(bytes.length).put(bytes)
- }
-
- buffer.rewind()
- seq.foreach { b =>
- assert(b === BINARY.extract(buffer))
- }
-
- buffer.rewind()
- seq.foreach(BINARY.append(_, buffer))
-
- buffer.rewind()
- seq.foreach { b =>
- val length = buffer.getInt
- assert(length === b.length)
-
+ },
+ (buffer: ByteBuffer) => {
+ val length = buffer.getInt()
val bytes = new Array[Byte](length)
buffer.get(bytes, 0, length)
- assert(b === bytes)
- }
- }
+ bytes
+ })
test("GENERIC") {
val buffer = ByteBuffer.allocate(512)
@@ -177,43 +124,58 @@ class ColumnTypeSuite extends FunSuite {
val length = buffer.getInt()
assert(length === serializedObj.length)
- val bytes = new Array[Byte](length)
- buffer.get(bytes, 0, length)
- assert(obj === SparkSqlSerializer.deserialize(bytes))
+ expectResult(obj, "Deserialized object didn't equal to the original object") {
+ val bytes = new Array[Byte](length)
+ buffer.get(bytes, 0, length)
+ SparkSqlSerializer.deserialize(bytes)
+ }
buffer.rewind()
buffer.putInt(serializedObj.length).put(serializedObj)
- buffer.rewind()
- assert(obj === SparkSqlSerializer.deserialize(GENERIC.extract(buffer)))
+ expectResult(obj, "Deserialized object didn't equal to the original object") {
+ buffer.rewind()
+ SparkSqlSerializer.deserialize(GENERIC.extract(buffer))
+ }
+ }
+
+ def testNativeColumnType[T <: NativeType](
+ columnType: NativeColumnType[T],
+ putter: (ByteBuffer, T#JvmType) => Unit,
+ getter: (ByteBuffer) => T#JvmType) {
+
+ testColumnType[T, T#JvmType](columnType, putter, getter)
}
- def testNumericColumnType[T <: DataType, JvmType](
+ def testColumnType[T <: DataType, JvmType](
columnType: ColumnType[T, JvmType],
- seq: Seq[JvmType],
- buffer: ByteBuffer,
putter: (ByteBuffer, JvmType) => Unit,
getter: (ByteBuffer) => JvmType) {
- val columnTypeName = columnType.getClass.getSimpleName.stripSuffix("$")
+ val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE)
+ val seq = (0 until 4).map(_ => makeRandomValue(columnType))
- test(s"$columnTypeName.extract") {
+ test(s"$columnType.extract") {
buffer.rewind()
seq.foreach(putter(buffer, _))
buffer.rewind()
- seq.foreach { i =>
- assert(i === columnType.extract(buffer))
+ seq.foreach { expected =>
+ assert(
+ expected === columnType.extract(buffer),
+ "Extracted value didn't equal to the original one")
}
}
- test(s"$columnTypeName.append") {
+ test(s"$columnType.append") {
buffer.rewind()
seq.foreach(columnType.append(_, buffer))
buffer.rewind()
- seq.foreach { i =>
- assert(i === getter(buffer))
+ seq.foreach { expected =>
+ assert(
+ expected === getter(buffer),
+ "Extracted value didn't equal to the original one")
}
}
}
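The `STRING` and `BINARY` cases above share the same length-prefixed convention: an `Int` byte count followed by the bytes themselves. A standalone sketch of a put/get round trip under that convention (helper names are illustrative):

```
import java.nio.ByteBuffer

object LengthPrefixedSketch extends App {
  val buffer = ByteBuffer.allocate(64)

  def put(s: String): Unit = {
    val bytes = s.getBytes("UTF-8")
    buffer.putInt(bytes.length).put(bytes) // length prefix, then payload
  }

  def get(): String = {
    val bytes = new Array[Byte](buffer.getInt())
    buffer.get(bytes)
    new String(bytes, "UTF-8")
  }

  Seq("hello", "spark").foreach(put)
  buffer.rewind()
  assert(get() == "hello" && get() == "spark")
}
```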
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarQuerySuite.scala
index 928851a385..70b2e85173 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarQuerySuite.scala
@@ -17,11 +17,11 @@
package org.apache.spark.sql.columnar
+import org.apache.spark.sql.{QueryTest, TestData}
import org.apache.spark.sql.execution.SparkLogicalPlan
import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.{TestData, DslQuerySuite}
-class ColumnarQuerySuite extends DslQuerySuite {
+class ColumnarQuerySuite extends QueryTest {
import TestData._
import TestSQLContext._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestData.scala
deleted file mode 100644
index ddcdede8d1..0000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestData.scala
+++ /dev/null
@@ -1,55 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.columnar
-
-import scala.util.Random
-
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-
-// TODO Enrich test data
-object ColumnarTestData {
- object GenericMutableRow {
- def apply(values: Any*) = {
- val row = new GenericMutableRow(values.length)
- row.indices.foreach { i =>
- row(i) = values(i)
- }
- row
- }
- }
-
- def randomBytes(length: Int) = {
- val bytes = new Array[Byte](length)
- Random.nextBytes(bytes)
- bytes
- }
-
- val nonNullRandomRow = GenericMutableRow(
- Random.nextInt(),
- Random.nextLong(),
- Random.nextFloat(),
- Random.nextDouble(),
- Random.nextBoolean(),
- Random.nextInt(Byte.MaxValue).asInstanceOf[Byte],
- Random.nextInt(Short.MaxValue).asInstanceOf[Short],
- Random.nextString(Random.nextInt(64)),
- randomBytes(Random.nextInt(64)),
- Map(Random.nextInt() -> Random.nextString(4)))
-
- val nullRow = GenericMutableRow(Seq.fill(10)(null): _*)
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
new file mode 100644
index 0000000000..04bdc43d95
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar
+
+import scala.collection.immutable.HashSet
+import scala.util.Random
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.catalyst.types.{DataType, NativeType}
+
+object ColumnarTestUtils {
+ def makeNullRow(length: Int) = {
+ val row = new GenericMutableRow(length)
+ (0 until length).foreach(row.setNullAt)
+ row
+ }
+
+ def makeRandomValue[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]): JvmType = {
+ def randomBytes(length: Int) = {
+ val bytes = new Array[Byte](length)
+ Random.nextBytes(bytes)
+ bytes
+ }
+
+ (columnType match {
+ case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte
+ case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort
+ case INT => Random.nextInt()
+ case LONG => Random.nextLong()
+ case FLOAT => Random.nextFloat()
+ case DOUBLE => Random.nextDouble()
+ case STRING => Random.nextString(Random.nextInt(32))
+ case BOOLEAN => Random.nextBoolean()
+ case BINARY => randomBytes(Random.nextInt(32))
+ case _ =>
+ // Using a random one-element map instead of an arbitrary object
+ Map(Random.nextInt() -> Random.nextString(Random.nextInt(32)))
+ }).asInstanceOf[JvmType]
+ }
+
+ def makeRandomValues(
+ head: ColumnType[_ <: DataType, _],
+ tail: ColumnType[_ <: DataType, _]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail)
+
+ def makeRandomValues(columnTypes: Seq[ColumnType[_ <: DataType, _]]): Seq[Any] = {
+ columnTypes.map(makeRandomValue(_))
+ }
+
+ def makeUniqueRandomValues[T <: DataType, JvmType](
+ columnType: ColumnType[T, JvmType],
+ count: Int): Seq[JvmType] = {
+
+ Iterator.iterate(HashSet.empty[JvmType]) { set =>
+ set + Iterator.continually(makeRandomValue(columnType)).filterNot(set.contains).next()
+ }.drop(count).next().toSeq
+ }
+
+ def makeRandomRow(
+ head: ColumnType[_ <: DataType, _],
+ tail: ColumnType[_ <: DataType, _]*): Row = makeRandomRow(Seq(head) ++ tail)
+
+ def makeRandomRow(columnTypes: Seq[ColumnType[_ <: DataType, _]]): Row = {
+ val row = new GenericMutableRow(columnTypes.length)
+ makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) =>
+ row(index) = value
+ }
+ row
+ }
+
+ def makeUniqueValuesAndSingleValueRows[T <: NativeType](
+ columnType: NativeColumnType[T],
+ count: Int) = {
+
+ val values = makeUniqueRandomValues(columnType, count)
+ val rows = values.map { value =>
+ val row = new GenericMutableRow(1)
+ row(0) = value
+ row
+ }
+
+ (values, rows)
+ }
+}
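`makeUniqueRandomValues` relies on an `Iterator.iterate` trick: each step grows the set by exactly one fresh element, so the `count`-th iterate holds `count` distinct values. The same pattern restated over plain `Int`s, runnable standalone:

```
import scala.collection.immutable.HashSet
import scala.util.Random

object UniqueValuesSketch extends App {
  def makeUnique(count: Int): Seq[Int] =
    Iterator.iterate(HashSet.empty[Int]) { set =>
      // Keep drawing random values until one is not yet in the set.
      set + Iterator.continually(Random.nextInt(100)).filterNot(set.contains).next()
    }.drop(count).next().toSeq

  val xs = makeUnique(5)
  assert(xs.distinct.size == 5)
}
```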
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
index d413d483f4..4a21eb6201 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
@@ -17,12 +17,29 @@
package org.apache.spark.sql.columnar
+import java.nio.ByteBuffer
+
import org.scalatest.FunSuite
-import org.apache.spark.sql.catalyst.types.DataType
+
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.catalyst.types.DataType
+
+class TestNullableColumnAccessor[T <: DataType, JvmType](
+ buffer: ByteBuffer,
+ columnType: ColumnType[T, JvmType])
+ extends BasicColumnAccessor(buffer, columnType)
+ with NullableColumnAccessor
+
+object TestNullableColumnAccessor {
+ def apply[T <: DataType, JvmType](buffer: ByteBuffer, columnType: ColumnType[T, JvmType]) = {
+ // Skips the column type ID
+ buffer.getInt()
+ new TestNullableColumnAccessor(buffer, columnType)
+ }
+}
class NullableColumnAccessorSuite extends FunSuite {
- import ColumnarTestData._
+ import ColumnarTestUtils._
Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC).foreach {
testNullableColumnAccessor(_)
@@ -30,30 +47,32 @@ class NullableColumnAccessorSuite extends FunSuite {
def testNullableColumnAccessor[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) {
val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
+ val nullRow = makeNullRow(1)
- test(s"$typeName accessor: empty column") {
- val builder = ColumnBuilder(columnType.typeId, 4)
- val accessor = ColumnAccessor(builder.build())
+ test(s"Nullable $typeName column accessor: empty column") {
+ val builder = TestNullableColumnBuilder(columnType)
+ val accessor = TestNullableColumnAccessor(builder.build(), columnType)
assert(!accessor.hasNext)
}
- test(s"$typeName accessor: access null values") {
- val builder = ColumnBuilder(columnType.typeId, 4)
+ test(s"Nullable $typeName column accessor: access null values") {
+ val builder = TestNullableColumnBuilder(columnType)
+ val randomRow = makeRandomRow(columnType)
(0 until 4).foreach { _ =>
- builder.appendFrom(nonNullRandomRow, columnType.typeId)
- builder.appendFrom(nullRow, columnType.typeId)
+ builder.appendFrom(randomRow, 0)
+ builder.appendFrom(nullRow, 0)
}
- val accessor = ColumnAccessor(builder.build())
+ val accessor = TestNullableColumnAccessor(builder.build(), columnType)
val row = new GenericMutableRow(1)
(0 until 4).foreach { _ =>
accessor.extractTo(row, 0)
- assert(row(0) === nonNullRandomRow(columnType.typeId))
+ assert(row(0) === randomRow(0))
accessor.extractTo(row, 0)
- assert(row(0) === null)
+ assert(row.isNullAt(0))
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
index 5222a47e1a..d9d1e1bfdd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
@@ -19,63 +19,71 @@ package org.apache.spark.sql.columnar
import org.scalatest.FunSuite
-import org.apache.spark.sql.catalyst.types.DataType
+import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.execution.SparkSqlSerializer
+class TestNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType])
+ extends BasicColumnBuilder[T, JvmType](new NoopColumnStats[T, JvmType], columnType)
+ with NullableColumnBuilder
+
+object TestNullableColumnBuilder {
+ def apply[T <: DataType, JvmType](columnType: ColumnType[T, JvmType], initialSize: Int = 0) = {
+ val builder = new TestNullableColumnBuilder(columnType)
+ builder.initialize(initialSize)
+ builder
+ }
+}
+
class NullableColumnBuilderSuite extends FunSuite {
- import ColumnarTestData._
+ import ColumnarTestUtils._
Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC).foreach {
testNullableColumnBuilder(_)
}
def testNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) {
- val columnBuilder = ColumnBuilder(columnType.typeId)
val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
test(s"$typeName column builder: empty column") {
- columnBuilder.initialize(4)
-
+ val columnBuilder = TestNullableColumnBuilder(columnType)
val buffer = columnBuilder.build()
- // For column type ID
- assert(buffer.getInt() === columnType.typeId)
- // For null count
- assert(buffer.getInt === 0)
+ expectResult(columnType.typeId, "Wrong column type ID")(buffer.getInt())
+ expectResult(0, "Wrong null count")(buffer.getInt())
assert(!buffer.hasRemaining)
}
test(s"$typeName column builder: buffer size auto growth") {
- columnBuilder.initialize(4)
+ val columnBuilder = TestNullableColumnBuilder(columnType)
+ val randomRow = makeRandomRow(columnType)
- (0 until 4) foreach { _ =>
- columnBuilder.appendFrom(nonNullRandomRow, columnType.typeId)
+ (0 until 4).foreach { _ =>
+ columnBuilder.appendFrom(randomRow, 0)
}
val buffer = columnBuilder.build()
- // For column type ID
- assert(buffer.getInt() === columnType.typeId)
- // For null count
- assert(buffer.getInt() === 0)
+ expectResult(columnType.typeId, "Wrong column type ID")(buffer.getInt())
+ expectResult(0, "Wrong null count")(buffer.getInt())
}
test(s"$typeName column builder: null values") {
- columnBuilder.initialize(4)
+ val columnBuilder = TestNullableColumnBuilder(columnType)
+ val randomRow = makeRandomRow(columnType)
+ val nullRow = makeNullRow(1)
- (0 until 4) foreach { _ =>
- columnBuilder.appendFrom(nonNullRandomRow, columnType.typeId)
- columnBuilder.appendFrom(nullRow, columnType.typeId)
+ (0 until 4).foreach { _ =>
+ columnBuilder.appendFrom(randomRow, 0)
+ columnBuilder.appendFrom(nullRow, 0)
}
val buffer = columnBuilder.build()
- // For column type ID
- assert(buffer.getInt() === columnType.typeId)
- // For null count
- assert(buffer.getInt() === 4)
+ expectResult(columnType.typeId, "Wrong column type ID")(buffer.getInt())
+ expectResult(4, "Wrong null count")(buffer.getInt())
+
// For null positions
- (1 to 7 by 2).foreach(i => assert(buffer.getInt() === i))
+ (1 to 7 by 2).foreach(expectResult(_, "Wrong null position")(buffer.getInt()))
// For non-null values
(0 until 4).foreach { _ =>
@@ -84,7 +92,8 @@ class NullableColumnBuilderSuite extends FunSuite {
} else {
columnType.extract(buffer)
}
- assert(actual === nonNullRandomRow(columnType.typeId))
+
+ assert(actual === randomRow(0), "Extracted value didn't equal the original one")
}
assert(!buffer.hasRemaining)
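The expected null positions 1, 3, 5, 7 checked above follow directly from the interleaved append order: a non-null row then a null row, four times over. A tiny standalone check of that arithmetic:

```
object NullPositionsSketch extends App {
  // true marks a null append; the builder records the ordinal of each null.
  val appended = (0 until 4).flatMap(_ => Seq(false, true))
  val nullPositions = appended.zipWithIndex.collect { case (true, i) => i }
  assert(nullPositions == Seq(1, 3, 5, 7))
}
```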
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
new file mode 100644
index 0000000000..184691ab5b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar.compression
+
+import java.nio.ByteBuffer
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.catalyst.types.NativeType
+import org.apache.spark.sql.columnar._
+import org.apache.spark.sql.columnar.ColumnarTestUtils._
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+
+class DictionaryEncodingSuite extends FunSuite {
+ testDictionaryEncoding(new IntColumnStats, INT)
+ testDictionaryEncoding(new LongColumnStats, LONG)
+ testDictionaryEncoding(new StringColumnStats, STRING)
+
+ def testDictionaryEncoding[T <: NativeType](
+ columnStats: NativeColumnStats[T],
+ columnType: NativeColumnType[T]) {
+
+ val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
+
+ def buildDictionary(buffer: ByteBuffer) = {
+ (0 until buffer.getInt()).map(columnType.extract(buffer) -> _.toShort).toMap
+ }
+
+ test(s"$DictionaryEncoding with $typeName: simple case") {
+ // -------------
+ // Tests encoder
+ // -------------
+
+ val builder = TestCompressibleColumnBuilder(columnStats, columnType, DictionaryEncoding)
+ val (values, rows) = makeUniqueValuesAndSingleValueRows(columnType, 2)
+
+ builder.initialize(0)
+ builder.appendFrom(rows(0), 0)
+ builder.appendFrom(rows(1), 0)
+ builder.appendFrom(rows(0), 0)
+ builder.appendFrom(rows(1), 0)
+
+ val buffer = builder.build()
+ val headerSize = CompressionScheme.columnHeaderSize(buffer)
+ // 4 extra bytes for dictionary size
+ val dictionarySize = 4 + values.map(columnType.actualSize).sum
+ // 4 `Short`s, 2 bytes each
+ val compressedSize = dictionarySize + 2 * 4
+ // 4 extra bytes for compression scheme type ID
+ expectResult(headerSize + 4 + compressedSize, "Wrong buffer capacity")(buffer.capacity)
+
+ // Skips column header
+ buffer.position(headerSize)
+ expectResult(DictionaryEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt())
+
+ val dictionary = buildDictionary(buffer)
+ Array[Short](0, 1).foreach { i =>
+ expectResult(i, "Wrong dictionary entry")(dictionary(values(i)))
+ }
+
+ Array[Short](0, 1, 0, 1).foreach {
+ expectResult(_, "Wrong column element value")(buffer.getShort())
+ }
+
+ // -------------
+ // Tests decoder
+ // -------------
+
+ // Rewinds, skips column header and 4 more bytes for compression scheme ID
+ buffer.rewind().position(headerSize + 4)
+
+ val decoder = new DictionaryEncoding.Decoder[T](buffer, columnType)
+
+ Array[Short](0, 1, 0, 1).foreach { i =>
+ expectResult(values(i), "Wrong decoded value")(decoder.next())
+ }
+
+ assert(!decoder.hasNext)
+ }
+ }
+
+ test(s"$DictionaryEncoding: overflow") {
+ val builder = TestCompressibleColumnBuilder(new IntColumnStats, INT, DictionaryEncoding)
+ builder.initialize(0)
+
+ (0 to Short.MaxValue).foreach { n =>
+ val row = new GenericMutableRow(1)
+ row.setInt(0, n)
+ builder.appendFrom(row, 0)
+ }
+
+ withClue("Dictionary overflowed, encoding should fail") {
+ intercept[Throwable] {
+ builder.build()
+ }
+ }
+ }
+}
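The overflow test above appends `Short.MaxValue + 1` distinct values. Assuming the encoder caps its dictionary at `Short.MaxValue` entries (the `MAX_DICT_SIZE` bound referenced earlier), that is exactly one value too many, which is why `build()` is expected to throw:

```
object OverflowArithmeticSketch extends App {
  // (0 to Short.MaxValue) is inclusive on both ends.
  val distinctValues = (0 to Short.MaxValue).size
  assert(distinctValues == Short.MaxValue + 1) // 32768 distinct values
}
```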
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
new file mode 100644
index 0000000000..2089ad120d
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar.compression
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.catalyst.types.NativeType
+import org.apache.spark.sql.columnar._
+import org.apache.spark.sql.columnar.ColumnarTestUtils._
+
+class RunLengthEncodingSuite extends FunSuite {
+ testRunLengthEncoding(new BooleanColumnStats, BOOLEAN)
+ testRunLengthEncoding(new ByteColumnStats, BYTE)
+ testRunLengthEncoding(new ShortColumnStats, SHORT)
+ testRunLengthEncoding(new IntColumnStats, INT)
+ testRunLengthEncoding(new LongColumnStats, LONG)
+ testRunLengthEncoding(new StringColumnStats, STRING)
+
+ def testRunLengthEncoding[T <: NativeType](
+ columnStats: NativeColumnStats[T],
+ columnType: NativeColumnType[T]) {
+
+ val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
+
+ test(s"$RunLengthEncoding with $typeName: simple case") {
+ // -------------
+ // Tests encoder
+ // -------------
+
+ val builder = TestCompressibleColumnBuilder(columnStats, columnType, RunLengthEncoding)
+ val (values, rows) = makeUniqueValuesAndSingleValueRows(columnType, 2)
+
+ builder.initialize(0)
+ builder.appendFrom(rows(0), 0)
+ builder.appendFrom(rows(0), 0)
+ builder.appendFrom(rows(1), 0)
+ builder.appendFrom(rows(1), 0)
+
+ val buffer = builder.build()
+ val headerSize = CompressionScheme.columnHeaderSize(buffer)
+ // 4 extra bytes per run for the run length
+ val compressedSize = values.map(columnType.actualSize(_) + 4).sum
+ // 4 extra bytes for compression scheme type ID
+ expectResult(headerSize + 4 + compressedSize, "Wrong buffer capacity")(buffer.capacity)
+
+ // Skips column header
+ buffer.position(headerSize)
+ expectResult(RunLengthEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt())
+
+ Array(0, 1).foreach { i =>
+ expectResult(values(i), "Wrong column element value")(columnType.extract(buffer))
+ expectResult(2, "Wrong run length")(buffer.getInt())
+ }
+
+ // -------------
+ // Tests decoder
+ // -------------
+
+ // Rewinds, skips column header and 4 more bytes for compression scheme ID
+ buffer.rewind().position(headerSize + 4)
+
+ val decoder = new RunLengthEncoding.Decoder[T](buffer, columnType)
+
+ Array(0, 0, 1, 1).foreach { i =>
+ expectResult(values(i), "Wrong decoded value")(decoder.next())
+ }
+
+ assert(!decoder.hasNext)
+ }
+
+ test(s"$RunLengthEncoding with $typeName: run length == 1") {
+ // -------------
+ // Tests encoder
+ // -------------
+
+ val builder = TestCompressibleColumnBuilder(columnStats, columnType, RunLengthEncoding)
+ val (values, rows) = makeUniqueValuesAndSingleValueRows(columnType, 2)
+
+ builder.initialize(0)
+ builder.appendFrom(rows(0), 0)
+ builder.appendFrom(rows(1), 0)
+
+ val buffer = builder.build()
+ val headerSize = CompressionScheme.columnHeaderSize(buffer)
+ // 4 bytes per run for the run length
+ val compressedSize = values.map(columnType.actualSize(_) + 4).sum
+ // 4 bytes for compression scheme type ID
+ expectResult(headerSize + 4 + compressedSize, "Wrong buffer capacity")(buffer.capacity)
+
+ // Skips column header
+ buffer.position(headerSize)
+ expectResult(RunLengthEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt())
+
+ Array(0, 1).foreach { i =>
+ expectResult(values(i), "Wrong column element value")(columnType.extract(buffer))
+ expectResult(1, "Wrong run length")(buffer.getInt())
+ }
+
+ // -------------
+ // Tests decoder
+ // -------------
+
+ // Rewinds, skips column header and 4 more bytes for compression scheme ID
+ buffer.rewind().position(headerSize + 4)
+
+ val decoder = new RunLengthEncoding.Decoder[T](buffer, columnType)
+
+ Array(0, 1).foreach { i =>
+ expectResult(values(i), "Wrong decoded value")(decoder.next())
+ }
+
+ assert(!decoder.hasNext)
+ }
+ }
+}
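The `[value, run length]` pairing these tests walk through is plain run-length encoding. A standalone restatement over `Int` lists, with illustrative helper names:

```
object RunLengthSketch extends App {
  // Encode adjacent equal values as (value, run length) pairs.
  def encode(xs: List[Int]): List[(Int, Int)] = xs match {
    case Nil => Nil
    case head :: _ =>
      val (run, rest) = xs.span(_ == head)
      (head, run.length) :: encode(rest)
  }

  // Expand each pair back into its run.
  def decode(runs: List[(Int, Int)]): List[Int] =
    runs.flatMap { case (v, n) => List.fill(n)(v) }

  val input = List(10, 10, 20, 20)
  assert(encode(input) == List((10, 2), (20, 2)))
  assert(decode(encode(input)) == input)
}
```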
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
new file mode 100644
index 0000000000..e0ec812863
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar.compression
+
+import org.apache.spark.sql.catalyst.types.NativeType
+import org.apache.spark.sql.columnar._
+
+class TestCompressibleColumnBuilder[T <: NativeType](
+ override val columnStats: NativeColumnStats[T],
+ override val columnType: NativeColumnType[T],
+ override val schemes: Seq[CompressionScheme])
+ extends NativeColumnBuilder(columnStats, columnType)
+ with NullableColumnBuilder
+ with CompressibleColumnBuilder[T] {
+
+ override protected def isWorthCompressing(encoder: Encoder) = true
+}
+
+object TestCompressibleColumnBuilder {
+ def apply[T <: NativeType](
+ columnStats: NativeColumnStats[T],
+ columnType: NativeColumnType[T],
+ scheme: CompressionScheme) = {
+
+ new TestCompressibleColumnBuilder(columnStats, columnType, Seq(scheme))
+ }
+}
+
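A note on the one behavioral override above: forcing `isWorthCompressing` to `true` makes the chosen scheme apply unconditionally, so each codec's encode/decode path is exercised deterministically instead of depending on the measured compression ratio of randomly generated test data.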