author    Wenchen Fan <cloud0fan@outlook.com>    2015-07-27 13:40:50 -0700
committer Reynold Xin <rxin@databricks.com>    2015-07-27 13:40:50 -0700
commit    3ab7525dceeb1c2f3c21efb1ee5a9c8bb0fd0c13 (patch)
tree      bb8f0c945cac9177aba34eb87b6318756f345f2b /sql
parent    8e7d2bee23dad1535846dae2dc31e35058db16cd (diff)
[SPARK-9355][SQL] Remove InternalRow.get generic getter call in columnar cache code
Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #7673 from cloud-fan/row-generic-getter-columnar and squashes the following commits:

88b1170 [Wenchen Fan] fix style
eeae712 [Wenchen Fan] Remove InternalRow.get generic getter call in columnar cache code
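For context, here is a minimal sketch (not part of this patch) of the getter pattern the change moves to: every untyped `row.get(ordinal)` call in the columnar cache code is replaced by the typed `row.get(ordinal, dataType)` call, where `dataType` is the column's Catalyst type. The helper name `readField` below is hypothetical.

```scala
// Illustrative sketch of the typed-getter call shape used throughout this diff.
// `readField` is a hypothetical helper, not part of the Spark codebase.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{DataType, IntegerType}

object TypedGetterSketch {
  def readField(row: InternalRow, ordinal: Int, dataType: DataType): Any =
    row.get(ordinal, dataType) // previously: row.get(ordinal), the generic getter

  def main(args: Array[String]): Unit = {
    val row = InternalRow(42)               // a single-field row
    println(readField(row, 0, IntegerType)) // prints 42
  }
}
```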
Diffstat (limited to 'sql')
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala | 12
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala | 18
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala | 6
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala | 49
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala | 2
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala | 14
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala | 12
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala | 30
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala | 18
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala | 18
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala | 21
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala | 2
12 files changed, 107 insertions, 95 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
index 931469bed6..4c29a09321 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
@@ -41,9 +41,9 @@ private[sql] trait ColumnAccessor {
protected def underlyingBuffer: ByteBuffer
}
-private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType](
+private[sql] abstract class BasicColumnAccessor[JvmType](
protected val buffer: ByteBuffer,
- protected val columnType: ColumnType[T, JvmType])
+ protected val columnType: ColumnType[JvmType])
extends ColumnAccessor {
protected def initialize() {}
@@ -93,14 +93,14 @@ private[sql] class StringColumnAccessor(buffer: ByteBuffer)
extends NativeColumnAccessor(buffer, STRING)
private[sql] class BinaryColumnAccessor(buffer: ByteBuffer)
- extends BasicColumnAccessor[BinaryType.type, Array[Byte]](buffer, BINARY)
+ extends BasicColumnAccessor[Array[Byte]](buffer, BINARY)
with NullableColumnAccessor
private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: Int, scale: Int)
extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale))
-private[sql] class GenericColumnAccessor(buffer: ByteBuffer)
- extends BasicColumnAccessor[DataType, Array[Byte]](buffer, GENERIC)
+private[sql] class GenericColumnAccessor(buffer: ByteBuffer, dataType: DataType)
+ extends BasicColumnAccessor[Array[Byte]](buffer, GENERIC(dataType))
with NullableColumnAccessor
private[sql] class DateColumnAccessor(buffer: ByteBuffer)
@@ -131,7 +131,7 @@ private[sql] object ColumnAccessor {
case BinaryType => new BinaryColumnAccessor(dup)
case DecimalType.Fixed(precision, scale) if precision < 19 =>
new FixedDecimalColumnAccessor(dup, precision, scale)
- case _ => new GenericColumnAccessor(dup)
+ case other => new GenericColumnAccessor(dup, other)
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
index 087c522397..454b7b91a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
@@ -46,9 +46,9 @@ private[sql] trait ColumnBuilder {
def build(): ByteBuffer
}
-private[sql] class BasicColumnBuilder[T <: DataType, JvmType](
+private[sql] class BasicColumnBuilder[JvmType](
val columnStats: ColumnStats,
- val columnType: ColumnType[T, JvmType])
+ val columnType: ColumnType[JvmType])
extends ColumnBuilder {
protected var columnName: String = _
@@ -78,16 +78,16 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType](
}
}
-private[sql] abstract class ComplexColumnBuilder[T <: DataType, JvmType](
+private[sql] abstract class ComplexColumnBuilder[JvmType](
columnStats: ColumnStats,
- columnType: ColumnType[T, JvmType])
- extends BasicColumnBuilder[T, JvmType](columnStats, columnType)
+ columnType: ColumnType[JvmType])
+ extends BasicColumnBuilder[JvmType](columnStats, columnType)
with NullableColumnBuilder
private[sql] abstract class NativeColumnBuilder[T <: AtomicType](
override val columnStats: ColumnStats,
override val columnType: NativeColumnType[T])
- extends BasicColumnBuilder[T, T#InternalType](columnStats, columnType)
+ extends BasicColumnBuilder[T#InternalType](columnStats, columnType)
with NullableColumnBuilder
with AllCompressionSchemes
with CompressibleColumnBuilder[T]
@@ -118,8 +118,8 @@ private[sql] class FixedDecimalColumnBuilder(
FIXED_DECIMAL(precision, scale))
// TODO (lian) Add support for array, struct and map
-private[sql] class GenericColumnBuilder
- extends ComplexColumnBuilder(new GenericColumnStats, GENERIC)
+private[sql] class GenericColumnBuilder(dataType: DataType)
+ extends ComplexColumnBuilder(new GenericColumnStats(dataType), GENERIC(dataType))
private[sql] class DateColumnBuilder extends NativeColumnBuilder(new DateColumnStats, DATE)
@@ -164,7 +164,7 @@ private[sql] object ColumnBuilder {
case BinaryType => new BinaryColumnBuilder
case DecimalType.Fixed(precision, scale) if precision < 19 =>
new FixedDecimalColumnBuilder(precision, scale)
- case _ => new GenericColumnBuilder
+ case other => new GenericColumnBuilder(other)
}
builder.initialize(initialSize, columnName, useCompression)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
index 7c63179af6..32a84b2676 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -252,11 +252,13 @@ private[sql] class FixedDecimalColumnStats extends ColumnStats {
InternalRow(lower, upper, nullCount, count, sizeInBytes)
}
-private[sql] class GenericColumnStats extends ColumnStats {
+private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats {
+ val columnType = GENERIC(dataType)
+
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
- sizeInBytes += GENERIC.actualSize(row, ordinal)
+ sizeInBytes += columnType.actualSize(row, ordinal)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index c0ca52751b..2863f6c230 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -31,14 +31,18 @@ import org.apache.spark.unsafe.types.UTF8String
* An abstract class that represents type of a column. Used to append/extract Java objects into/from
* the underlying [[ByteBuffer]] of a column.
*
- * @param typeId A unique ID representing the type.
- * @param defaultSize Default size in bytes for one element of type T (e.g. 4 for `Int`).
- * @tparam T Scala data type for the column.
* @tparam JvmType Underlying Java type to represent the elements.
*/
-private[sql] sealed abstract class ColumnType[T <: DataType, JvmType](
- val typeId: Int,
- val defaultSize: Int) {
+private[sql] sealed abstract class ColumnType[JvmType] {
+
+ // The catalyst data type of this column.
+ def dataType: DataType
+
+ // A unique ID representing the type.
+ def typeId: Int
+
+ // Default size in bytes for one element of type T (e.g. 4 for `Int`).
+ def defaultSize: Int
/**
* Extracts a value out of the buffer at the buffer's current position.
@@ -90,7 +94,7 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType](
* boxing/unboxing costs whenever possible.
*/
def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
- to(toOrdinal) = from.get(fromOrdinal)
+ to.update(toOrdinal, from.get(fromOrdinal, dataType))
}
/**
@@ -103,9 +107,9 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType](
private[sql] abstract class NativeColumnType[T <: AtomicType](
val dataType: T,
- typeId: Int,
- defaultSize: Int)
- extends ColumnType[T, T#InternalType](typeId, defaultSize) {
+ val typeId: Int,
+ val defaultSize: Int)
+ extends ColumnType[T#InternalType] {
/**
* Scala TypeTag. Can be used to create primitive arrays and hash tables.
@@ -400,10 +404,10 @@ private[sql] object FIXED_DECIMAL {
val defaultSize = 8
}
-private[sql] sealed abstract class ByteArrayColumnType[T <: DataType](
- typeId: Int,
- defaultSize: Int)
- extends ColumnType[T, Array[Byte]](typeId, defaultSize) {
+private[sql] sealed abstract class ByteArrayColumnType(
+ val typeId: Int,
+ val defaultSize: Int)
+ extends ColumnType[Array[Byte]] {
override def actualSize(row: InternalRow, ordinal: Int): Int = {
getField(row, ordinal).length + 4
@@ -421,9 +425,12 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType](
}
}
-private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](11, 16) {
+private[sql] object BINARY extends ByteArrayColumnType(11, 16) {
+
+ def dataType: DataType = BooleanType
+
override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = {
- row(ordinal) = value
+ row.update(ordinal, value)
}
override def getField(row: InternalRow, ordinal: Int): Array[Byte] = {
@@ -434,18 +441,18 @@ private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](11, 16)
// Used to process generic objects (all types other than those listed above). Objects should be
// serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized
// byte array.
-private[sql] object GENERIC extends ByteArrayColumnType[DataType](12, 16) {
+private[sql] case class GENERIC(dataType: DataType) extends ByteArrayColumnType(12, 16) {
override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = {
- row(ordinal) = SparkSqlSerializer.deserialize[Any](value)
+ row.update(ordinal, SparkSqlSerializer.deserialize[Any](value))
}
override def getField(row: InternalRow, ordinal: Int): Array[Byte] = {
- SparkSqlSerializer.serialize(row.get(ordinal))
+ SparkSqlSerializer.serialize(row.get(ordinal, dataType))
}
}
private[sql] object ColumnType {
- def apply(dataType: DataType): ColumnType[_, _] = {
+ def apply(dataType: DataType): ColumnType[_] = {
dataType match {
case BooleanType => BOOLEAN
case ByteType => BYTE
@@ -460,7 +467,7 @@ private[sql] object ColumnType {
case BinaryType => BINARY
case DecimalType.Fixed(precision, scale) if precision < 19 =>
FIXED_DECIMAL(precision, scale)
- case _ => GENERIC
+ case other => GENERIC(other)
}
}
}
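To make the typed getter possible for the catch-all column type, the patch turns `GENERIC` from a singleton object into a case class carrying the column's Catalyst `DataType`. Below is a minimal, self-contained sketch of that shape; the `Sketch*` names are simplified stand-ins rather than the Spark classes, and plain Java serialization is used instead of `SparkSqlSerializer` to keep the example runnable on its own.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType}

// Simplified stand-in for ColumnType[JvmType]: the Catalyst data type is now part
// of the contract, as introduced by this patch.
trait SketchColumnType[JvmType] {
  def dataType: DataType
  def getField(row: InternalRow, ordinal: Int): JvmType
}

// Counterpart of GENERIC(dataType): knowing the DataType lets getField avoid the
// removed generic getter and call row.get(ordinal, dataType) instead.
case class SketchGeneric(dataType: DataType) extends SketchColumnType[Array[Byte]] {
  override def getField(row: InternalRow, ordinal: Int): Array[Byte] = {
    val bos = new java.io.ByteArrayOutputStream()
    val oos = new java.io.ObjectOutputStream(bos)
    oos.writeObject(row.get(ordinal, dataType)) // typed getter, as in the patch
    oos.close()
    bos.toByteArray
  }
}

object SketchGenericUsage {
  def main(args: Array[String]): Unit = {
    val generic = SketchGeneric(MapType(IntegerType, StringType))
    val row = InternalRow(Map(1 -> "a"))
    println(generic.getField(row, 0).length) // size of the serialized map in bytes
  }
}
```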
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
index 4eaec6d853..b1ef9b2ef7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
@@ -46,7 +46,7 @@ private[sql] trait Decoder[T <: AtomicType] {
private[sql] trait CompressionScheme {
def typeId: Int
- def supports(columnType: ColumnType[_, _]): Boolean
+ def supports(columnType: ColumnType[_]): Boolean
def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
index 6150df6930..c91d960a09 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
@@ -32,7 +32,7 @@ import org.apache.spark.util.Utils
private[sql] case object PassThrough extends CompressionScheme {
override val typeId = 0
- override def supports(columnType: ColumnType[_, _]): Boolean = true
+ override def supports(columnType: ColumnType[_]): Boolean = true
override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] = {
new this.Encoder[T](columnType)
@@ -78,7 +78,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme {
new this.Decoder(buffer, columnType)
}
- override def supports(columnType: ColumnType[_, _]): Boolean = columnType match {
+ override def supports(columnType: ColumnType[_]): Boolean = columnType match {
case INT | LONG | SHORT | BYTE | STRING | BOOLEAN => true
case _ => false
}
@@ -128,7 +128,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme {
while (from.hasRemaining) {
columnType.extract(from, value, 0)
- if (value.get(0) == currentValue.get(0)) {
+ if (value.get(0, columnType.dataType) == currentValue.get(0, columnType.dataType)) {
currentRun += 1
} else {
// Writes current run
@@ -189,7 +189,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme {
new this.Encoder[T](columnType)
}
- override def supports(columnType: ColumnType[_, _]): Boolean = columnType match {
+ override def supports(columnType: ColumnType[_]): Boolean = columnType match {
case INT | LONG | STRING => true
case _ => false
}
@@ -304,7 +304,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme {
(new this.Encoder).asInstanceOf[compression.Encoder[T]]
}
- override def supports(columnType: ColumnType[_, _]): Boolean = columnType == BOOLEAN
+ override def supports(columnType: ColumnType[_]): Boolean = columnType == BOOLEAN
class Encoder extends compression.Encoder[BooleanType.type] {
private var _uncompressedSize = 0
@@ -392,7 +392,7 @@ private[sql] case object IntDelta extends CompressionScheme {
(new Encoder).asInstanceOf[compression.Encoder[T]]
}
- override def supports(columnType: ColumnType[_, _]): Boolean = columnType == INT
+ override def supports(columnType: ColumnType[_]): Boolean = columnType == INT
class Encoder extends compression.Encoder[IntegerType.type] {
protected var _compressedSize: Int = 0
@@ -472,7 +472,7 @@ private[sql] case object LongDelta extends CompressionScheme {
(new Encoder).asInstanceOf[compression.Encoder[T]]
}
- override def supports(columnType: ColumnType[_, _]): Boolean = columnType == LONG
+ override def supports(columnType: ColumnType[_]): Boolean = columnType == LONG
class Encoder extends compression.Encoder[LongType.type] {
protected var _compressedSize: Int = 0
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index 31e7b0e72e..4499a72070 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -58,15 +58,15 @@ class ColumnStatsSuite extends SparkFunSuite {
val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
rows.foreach(columnStats.gatherStats(_, 0))
- val values = rows.take(10).map(_.get(0).asInstanceOf[T#InternalType])
+ val values = rows.take(10).map(_.get(0, columnType.dataType).asInstanceOf[T#InternalType])
val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]]
val stats = columnStats.collectedStatistics
- assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0))
- assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1))
- assertResult(10, "Wrong null count")(stats.get(2))
- assertResult(20, "Wrong row count")(stats.get(3))
- assertResult(stats.get(4), "Wrong size in bytes") {
+ assertResult(values.min(ordering), "Wrong lower bound")(stats.genericGet(0))
+ assertResult(values.max(ordering), "Wrong upper bound")(stats.genericGet(1))
+ assertResult(10, "Wrong null count")(stats.genericGet(2))
+ assertResult(20, "Wrong row count")(stats.genericGet(3))
+ assertResult(stats.genericGet(4), "Wrong size in bytes") {
rows.map { row =>
if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
}.sum
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index 4d46a65705..8f024690ef 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -32,13 +32,15 @@ import org.apache.spark.unsafe.types.UTF8String
class ColumnTypeSuite extends SparkFunSuite with Logging {
- val DEFAULT_BUFFER_SIZE = 512
+ private val DEFAULT_BUFFER_SIZE = 512
+ private val MAP_GENERIC = GENERIC(MapType(IntegerType, StringType))
test("defaultSize") {
val checks = Map(
BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, DATE -> 4,
LONG -> 8, TIMESTAMP -> 8, FLOAT -> 4, DOUBLE -> 8,
- STRING -> 8, BINARY -> 16, FIXED_DECIMAL(15, 10) -> 8, GENERIC -> 16)
+ STRING -> 8, BINARY -> 16, FIXED_DECIMAL(15, 10) -> 8,
+ MAP_GENERIC -> 16)
checks.foreach { case (columnType, expectedSize) =>
assertResult(expectedSize, s"Wrong defaultSize for $columnType") {
@@ -48,8 +50,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
}
test("actualSize") {
- def checkActualSize[T <: DataType, JvmType](
- columnType: ColumnType[T, JvmType],
+ def checkActualSize[JvmType](
+ columnType: ColumnType[JvmType],
value: JvmType,
expected: Int): Unit = {
@@ -74,7 +76,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
val generic = Map(1 -> "a")
- checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8)
+ checkActualSize(MAP_GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8)
}
testNativeColumnType(BOOLEAN)(
@@ -123,7 +125,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
UTF8String.fromBytes(bytes)
})
- testColumnType[BinaryType.type, Array[Byte]](
+ testColumnType[Array[Byte]](
BINARY,
(buffer: ByteBuffer, bytes: Array[Byte]) => {
buffer.putInt(bytes.length).put(bytes)
@@ -140,7 +142,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
val obj = Map(1 -> "spark", 2 -> "sql")
val serializedObj = SparkSqlSerializer.serialize(obj)
- GENERIC.append(SparkSqlSerializer.serialize(obj), buffer)
+ MAP_GENERIC.append(SparkSqlSerializer.serialize(obj), buffer)
buffer.rewind()
val length = buffer.getInt()
@@ -157,7 +159,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
assertResult(obj, "Deserialized object didn't equal to the original object") {
buffer.rewind()
- SparkSqlSerializer.deserialize(GENERIC.extract(buffer))
+ SparkSqlSerializer.deserialize(MAP_GENERIC.extract(buffer))
}
}
@@ -170,7 +172,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
val obj = CustomClass(Int.MaxValue, Long.MaxValue)
val serializedObj = serializer.serialize(obj).array()
- GENERIC.append(serializer.serialize(obj).array(), buffer)
+ MAP_GENERIC.append(serializer.serialize(obj).array(), buffer)
buffer.rewind()
val length = buffer.getInt
@@ -192,7 +194,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
assertResult(obj, "Custom deserialized object didn't equal the original object") {
buffer.rewind()
- serializer.deserialize(ByteBuffer.wrap(GENERIC.extract(buffer)))
+ serializer.deserialize(ByteBuffer.wrap(MAP_GENERIC.extract(buffer)))
}
}
@@ -201,11 +203,11 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
(putter: (ByteBuffer, T#InternalType) => Unit,
getter: (ByteBuffer) => T#InternalType): Unit = {
- testColumnType[T, T#InternalType](columnType, putter, getter)
+ testColumnType[T#InternalType](columnType, putter, getter)
}
- def testColumnType[T <: DataType, JvmType](
- columnType: ColumnType[T, JvmType],
+ def testColumnType[JvmType](
+ columnType: ColumnType[JvmType],
putter: (ByteBuffer, JvmType) => Unit,
getter: (ByteBuffer) => JvmType): Unit = {
@@ -262,7 +264,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
}
}
- assertResult(GENERIC) {
+ assertResult(GENERIC(DecimalType(19, 0))) {
ColumnType(DecimalType(19, 0))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
index d986133973..79bb7d072f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
@@ -31,7 +31,7 @@ object ColumnarTestUtils {
row
}
- def makeRandomValue[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]): JvmType = {
+ def makeRandomValue[JvmType](columnType: ColumnType[JvmType]): JvmType = {
def randomBytes(length: Int) = {
val bytes = new Array[Byte](length)
Random.nextBytes(bytes)
@@ -58,15 +58,15 @@ object ColumnarTestUtils {
}
def makeRandomValues(
- head: ColumnType[_ <: DataType, _],
- tail: ColumnType[_ <: DataType, _]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail)
+ head: ColumnType[_],
+ tail: ColumnType[_]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail)
- def makeRandomValues(columnTypes: Seq[ColumnType[_ <: DataType, _]]): Seq[Any] = {
+ def makeRandomValues(columnTypes: Seq[ColumnType[_]]): Seq[Any] = {
columnTypes.map(makeRandomValue(_))
}
- def makeUniqueRandomValues[T <: DataType, JvmType](
- columnType: ColumnType[T, JvmType],
+ def makeUniqueRandomValues[JvmType](
+ columnType: ColumnType[JvmType],
count: Int): Seq[JvmType] = {
Iterator.iterate(HashSet.empty[JvmType]) { set =>
@@ -75,10 +75,10 @@ object ColumnarTestUtils {
}
def makeRandomRow(
- head: ColumnType[_ <: DataType, _],
- tail: ColumnType[_ <: DataType, _]*): InternalRow = makeRandomRow(Seq(head) ++ tail)
+ head: ColumnType[_],
+ tail: ColumnType[_]*): InternalRow = makeRandomRow(Seq(head) ++ tail)
- def makeRandomRow(columnTypes: Seq[ColumnType[_ <: DataType, _]]): InternalRow = {
+ def makeRandomRow(columnTypes: Seq[ColumnType[_]]): InternalRow = {
val row = new GenericMutableRow(columnTypes.length)
makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) =>
row(index) = value
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
index d421f4d8d0..f4f6c7649b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
@@ -21,17 +21,17 @@ import java.nio.ByteBuffer
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{StringType, ArrayType, DataType}
-class TestNullableColumnAccessor[T <: DataType, JvmType](
+class TestNullableColumnAccessor[JvmType](
buffer: ByteBuffer,
- columnType: ColumnType[T, JvmType])
+ columnType: ColumnType[JvmType])
extends BasicColumnAccessor(buffer, columnType)
with NullableColumnAccessor
object TestNullableColumnAccessor {
- def apply[T <: DataType, JvmType](buffer: ByteBuffer, columnType: ColumnType[T, JvmType])
- : TestNullableColumnAccessor[T, JvmType] = {
+ def apply[JvmType](buffer: ByteBuffer, columnType: ColumnType[JvmType])
+ : TestNullableColumnAccessor[JvmType] = {
// Skips the column type ID
buffer.getInt()
new TestNullableColumnAccessor(buffer, columnType)
@@ -43,13 +43,13 @@ class NullableColumnAccessorSuite extends SparkFunSuite {
Seq(
BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE,
- STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC)
+ STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType)))
.foreach {
testNullableColumnAccessor(_)
}
- def testNullableColumnAccessor[T <: DataType, JvmType](
- columnType: ColumnType[T, JvmType]): Unit = {
+ def testNullableColumnAccessor[JvmType](
+ columnType: ColumnType[JvmType]): Unit = {
val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
val nullRow = makeNullRow(1)
@@ -75,7 +75,7 @@ class NullableColumnAccessorSuite extends SparkFunSuite {
(0 until 4).foreach { _ =>
assert(accessor.hasNext)
accessor.extractTo(row, 0)
- assert(row.get(0) === randomRow.get(0))
+ assert(row.get(0, columnType.dataType) === randomRow.get(0, columnType.dataType))
assert(accessor.hasNext)
accessor.extractTo(row, 0)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
index cd8bf75ff1..241d09ea20 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
@@ -21,13 +21,13 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.sql.types._
-class TestNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType])
- extends BasicColumnBuilder[T, JvmType](new NoopColumnStats, columnType)
+class TestNullableColumnBuilder[JvmType](columnType: ColumnType[JvmType])
+ extends BasicColumnBuilder[JvmType](new NoopColumnStats, columnType)
with NullableColumnBuilder
object TestNullableColumnBuilder {
- def apply[T <: DataType, JvmType](columnType: ColumnType[T, JvmType], initialSize: Int = 0)
- : TestNullableColumnBuilder[T, JvmType] = {
+ def apply[JvmType](columnType: ColumnType[JvmType], initialSize: Int = 0)
+ : TestNullableColumnBuilder[JvmType] = {
val builder = new TestNullableColumnBuilder(columnType)
builder.initialize(initialSize)
builder
@@ -39,13 +39,13 @@ class NullableColumnBuilderSuite extends SparkFunSuite {
Seq(
BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE,
- STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC)
+ STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType)))
.foreach {
testNullableColumnBuilder(_)
}
- def testNullableColumnBuilder[T <: DataType, JvmType](
- columnType: ColumnType[T, JvmType]): Unit = {
+ def testNullableColumnBuilder[JvmType](
+ columnType: ColumnType[JvmType]): Unit = {
val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
@@ -92,13 +92,14 @@ class NullableColumnBuilderSuite extends SparkFunSuite {
// For non-null values
(0 until 4).foreach { _ =>
- val actual = if (columnType == GENERIC) {
- SparkSqlSerializer.deserialize[Any](GENERIC.extract(buffer))
+ val actual = if (columnType.isInstanceOf[GENERIC]) {
+ SparkSqlSerializer.deserialize[Any](columnType.extract(buffer).asInstanceOf[Array[Byte]])
} else {
columnType.extract(buffer)
}
- assert(actual === randomRow.get(0), "Extracted value didn't equal to the original one")
+ assert(actual === randomRow.get(0, columnType.dataType),
+ "Extracted value didn't equal to the original one")
}
assert(!buffer.hasRemaining)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
index 33092c83a1..9a2948c59b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
@@ -33,7 +33,7 @@ class BooleanBitSetSuite extends SparkFunSuite {
val builder = TestCompressibleColumnBuilder(new NoopColumnStats, BOOLEAN, BooleanBitSet)
val rows = Seq.fill[InternalRow](count)(makeRandomRow(BOOLEAN))
- val values = rows.map(_.get(0))
+ val values = rows.map(_.getBoolean(0))
rows.foreach(builder.appendFrom(_, 0))
val buffer = builder.build()