author     Wenchen Fan <cloud0fan@outlook.com>    2015-07-27 13:40:50 -0700
committer  Reynold Xin <rxin@databricks.com>      2015-07-27 13:40:50 -0700
commit     3ab7525dceeb1c2f3c21efb1ee5a9c8bb0fd0c13
tree       bb8f0c945cac9177aba34eb87b6318756f345f2b
parent     8e7d2bee23dad1535846dae2dc31e35058db16cd
[SPARK-9355][SQL] Remove InternalRow.get generic getter call in columnar cache code
Author: Wenchen Fan <cloud0fan@outlook.com>
Closes #7673 from cloud-fan/row-generic-getter-columnar and squashes the following commits:
88b1170 [Wenchen Fan] fix style
eeae712 [Wenchen Fan] Remove Internal.get generic getter call in columnar cache code
12 files changed, 107 insertions, 95 deletions
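The core idea of the change: `ColumnType` used to carry a Catalyst `DataType` only as a type parameter, so the generic code path had to fall back to the untyped `InternalRow.get(ordinal)`. After this patch the Catalyst type is a value member, and `GENERIC` becomes a case class parameterized by it, so every access can go through the typed getter. A minimal standalone sketch of the new shape (stub types for illustration only, not Spark's actual classes):

    // Stub stand-ins for Catalyst data types (illustration only).
    sealed trait DataType
    case object IntegerType extends DataType
    case class MapType(keyType: DataType, valueType: DataType) extends DataType

    // After the patch: a single JvmType parameter, plus the Catalyst type as a member.
    abstract class ColumnType[JvmType] {
      def dataType: DataType   // now available to extract/append/copyField
      def typeId: Int
      def defaultSize: Int
    }

    // GENERIC carries the concrete Catalyst type instead of being a singleton.
    case class GENERIC(dataType: DataType) extends ColumnType[Array[Byte]] {
      val typeId = 12
      val defaultSize = 16
    }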
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
index 931469bed6..4c29a09321 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
@@ -41,9 +41,9 @@ private[sql] trait ColumnAccessor {
   protected def underlyingBuffer: ByteBuffer
 }
 
-private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType](
+private[sql] abstract class BasicColumnAccessor[JvmType](
     protected val buffer: ByteBuffer,
-    protected val columnType: ColumnType[T, JvmType])
+    protected val columnType: ColumnType[JvmType])
   extends ColumnAccessor {
 
   protected def initialize() {}
@@ -93,14 +93,14 @@ private[sql] class StringColumnAccessor(buffer: ByteBuffer)
   extends NativeColumnAccessor(buffer, STRING)
 
 private[sql] class BinaryColumnAccessor(buffer: ByteBuffer)
-  extends BasicColumnAccessor[BinaryType.type, Array[Byte]](buffer, BINARY)
+  extends BasicColumnAccessor[Array[Byte]](buffer, BINARY)
   with NullableColumnAccessor
 
 private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: Int, scale: Int)
   extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale))
 
-private[sql] class GenericColumnAccessor(buffer: ByteBuffer)
-  extends BasicColumnAccessor[DataType, Array[Byte]](buffer, GENERIC)
+private[sql] class GenericColumnAccessor(buffer: ByteBuffer, dataType: DataType)
+  extends BasicColumnAccessor[Array[Byte]](buffer, GENERIC(dataType))
   with NullableColumnAccessor
 
 private[sql] class DateColumnAccessor(buffer: ByteBuffer)
@@ -131,7 +131,7 @@ private[sql] object ColumnAccessor {
       case BinaryType => new BinaryColumnAccessor(dup)
       case DecimalType.Fixed(precision, scale) if precision < 19 =>
         new FixedDecimalColumnAccessor(dup, precision, scale)
-      case _ => new GenericColumnAccessor(dup)
+      case other => new GenericColumnAccessor(dup, other)
     }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
index 087c522397..454b7b91a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
@@ -46,9 +46,9 @@ private[sql] trait ColumnBuilder {
   def build(): ByteBuffer
 }
 
-private[sql] class BasicColumnBuilder[T <: DataType, JvmType](
+private[sql] class BasicColumnBuilder[JvmType](
    val columnStats: ColumnStats,
-    val columnType: ColumnType[T, JvmType])
+    val columnType: ColumnType[JvmType])
   extends ColumnBuilder {
 
   protected var columnName: String = _
@@ -78,16 +78,16 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType](
   }
 }
 
-private[sql] abstract class ComplexColumnBuilder[T <: DataType, JvmType](
+private[sql] abstract class ComplexColumnBuilder[JvmType](
     columnStats: ColumnStats,
-    columnType: ColumnType[T, JvmType])
-  extends BasicColumnBuilder[T, JvmType](columnStats, columnType)
+    columnType: ColumnType[JvmType])
+  extends BasicColumnBuilder[JvmType](columnStats, columnType)
   with NullableColumnBuilder
 
 private[sql] abstract class NativeColumnBuilder[T <: AtomicType](
     override val columnStats: ColumnStats,
     override val columnType: NativeColumnType[T])
-  extends BasicColumnBuilder[T, T#InternalType](columnStats, columnType)
+  extends BasicColumnBuilder[T#InternalType](columnStats, columnType)
   with NullableColumnBuilder
   with AllCompressionSchemes
   with CompressibleColumnBuilder[T]
@@ -118,8 +118,8 @@ private[sql] class FixedDecimalColumnBuilder(
       FIXED_DECIMAL(precision, scale))
 
 // TODO (lian) Add support for array, struct and map
-private[sql] class GenericColumnBuilder
-  extends ComplexColumnBuilder(new GenericColumnStats, GENERIC)
+private[sql] class GenericColumnBuilder(dataType: DataType)
+  extends ComplexColumnBuilder(new GenericColumnStats(dataType), GENERIC(dataType))
 
 private[sql] class DateColumnBuilder extends NativeColumnBuilder(new DateColumnStats, DATE)
@@ -164,7 +164,7 @@ private[sql] object ColumnBuilder {
       case BinaryType => new BinaryColumnBuilder
       case DecimalType.Fixed(precision, scale) if precision < 19 =>
         new FixedDecimalColumnBuilder(precision, scale)
-      case _ => new GenericColumnBuilder
+      case other => new GenericColumnBuilder(other)
     }
     builder.initialize(initialSize, columnName, useCompression)
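Both factory objects above switch their catch-all arm from `case _` to `case other`, binding the otherwise-discarded `DataType` so it can be threaded into the generic builder and accessor. A simplified model of that dispatch (hypothetical standalone stubs, reusing the `DataType` sketch above):

    trait ColumnBuilder
    class IntColumnBuilder extends ColumnBuilder
    // The generic builder now needs the concrete Catalyst type up front.
    class GenericColumnBuilder(dataType: DataType) extends ColumnBuilder

    def newBuilder(dataType: DataType): ColumnBuilder = dataType match {
      case IntegerType => new IntColumnBuilder             // dedicated native builder
      case other       => new GenericColumnBuilder(other)  // was: case _ => new GenericColumnBuilder
    }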
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
index 7c63179af6..32a84b2676 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -252,11 +252,13 @@ private[sql] class FixedDecimalColumnStats extends ColumnStats {
     InternalRow(lower, upper, nullCount, count, sizeInBytes)
 }
 
-private[sql] class GenericColumnStats extends ColumnStats {
+private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats {
+  val columnType = GENERIC(dataType)
+
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
     super.gatherStats(row, ordinal)
     if (!row.isNullAt(ordinal)) {
-      sizeInBytes += GENERIC.actualSize(row, ordinal)
+      sizeInBytes += columnType.actualSize(row, ordinal)
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index c0ca52751b..2863f6c230 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -31,14 +31,18 @@ import org.apache.spark.unsafe.types.UTF8String
 /**
  * An abstract class that represents type of a column. Used to append/extract Java objects into/from
  * the underlying [[ByteBuffer]] of a column.
  *
- * @param typeId A unique ID representing the type.
- * @param defaultSize Default size in bytes for one element of type T (e.g. 4 for `Int`).
- * @tparam T Scala data type for the column.
  * @tparam JvmType Underlying Java type to represent the elements.
  */
-private[sql] sealed abstract class ColumnType[T <: DataType, JvmType](
-    val typeId: Int,
-    val defaultSize: Int) {
+private[sql] sealed abstract class ColumnType[JvmType] {
+
+  // The catalyst data type of this column.
+  def dataType: DataType
+
+  // A unique ID representing the type.
+  def typeId: Int
+
+  // Default size in bytes for one element of this type (e.g. 4 for `Int`).
+  def defaultSize: Int
 
   /**
    * Extracts a value out of the buffer at the buffer's current position.
@@ -90,7 +94,7 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType](
    * boxing/unboxing costs whenever possible.
    */
   def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
-    to(toOrdinal) = from.get(fromOrdinal)
+    to.update(toOrdinal, from.get(fromOrdinal, dataType))
   }
 
   /**
@@ -103,9 +107,9 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType](
 
 private[sql] abstract class NativeColumnType[T <: AtomicType](
     val dataType: T,
-    typeId: Int,
-    defaultSize: Int)
-  extends ColumnType[T, T#InternalType](typeId, defaultSize) {
+    val typeId: Int,
+    val defaultSize: Int)
+  extends ColumnType[T#InternalType] {
 
   /**
    * Scala TypeTag. Can be used to create primitive arrays and hash tables.
@@ -400,10 +404,10 @@ private[sql] object FIXED_DECIMAL {
   val defaultSize = 8
 }
 
-private[sql] sealed abstract class ByteArrayColumnType[T <: DataType](
-    typeId: Int,
-    defaultSize: Int)
-  extends ColumnType[T, Array[Byte]](typeId, defaultSize) {
+private[sql] sealed abstract class ByteArrayColumnType(
+    val typeId: Int,
+    val defaultSize: Int)
+  extends ColumnType[Array[Byte]] {
 
   override def actualSize(row: InternalRow, ordinal: Int): Int = {
     getField(row, ordinal).length + 4
@@ -421,9 +425,12 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType](
   }
 }
 
-private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](11, 16) {
+private[sql] object BINARY extends ByteArrayColumnType(11, 16) {
+
+  def dataType: DataType = BinaryType
+
   override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = {
-    row(ordinal) = value
+    row.update(ordinal, value)
   }
 
   override def getField(row: InternalRow, ordinal: Int): Array[Byte] = {
@@ -434,18 +441,18 @@ private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](11, 16)
 // Used to process generic objects (all types other than those listed above). Objects should be
 // serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized
 // byte array.
-private[sql] object GENERIC extends ByteArrayColumnType[DataType](12, 16) {
+private[sql] case class GENERIC(dataType: DataType) extends ByteArrayColumnType(12, 16) {
   override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = {
-    row(ordinal) = SparkSqlSerializer.deserialize[Any](value)
+    row.update(ordinal, SparkSqlSerializer.deserialize[Any](value))
   }
 
   override def getField(row: InternalRow, ordinal: Int): Array[Byte] = {
-    SparkSqlSerializer.serialize(row.get(ordinal))
+    SparkSqlSerializer.serialize(row.get(ordinal, dataType))
   }
 }
 
 private[sql] object ColumnType {
-  def apply(dataType: DataType): ColumnType[_, _] = {
+  def apply(dataType: DataType): ColumnType[_] = {
     dataType match {
       case BooleanType => BOOLEAN
       case ByteType => BYTE
@@ -460,7 +467,7 @@ private[sql] object ColumnType {
       case BinaryType => BINARY
       case DecimalType.Fixed(precision, scale) if precision < 19 =>
         FIXED_DECIMAL(precision, scale)
-      case _ => GENERIC
+      case other => GENERIC(other)
     }
   }
 }
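This file holds the heart of the patch: `copyField` and `GENERIC.getField` previously went through the untyped `InternalRow.get(ordinal)`; with `dataType` in scope they can call the typed `get(ordinal, dataType)` overload, which lets row implementations dispatch to specialized, non-boxing accessors. A runnable sketch of the idea, with Java serialization standing in for Spark's Kryo-based `SparkSqlSerializer` (stub types, not the real classes):

    import java.io._

    sealed trait DataType
    case object StringType extends DataType

    trait InternalRow {
      def get(ordinal: Int, dataType: DataType): Any  // typed getter
      def isNullAt(ordinal: Int): Boolean
    }

    // Stand-in serializer (Spark uses the Kryo-based SparkSqlSerializer here).
    def serialize(v: Any): Array[Byte] = {
      val bos = new ByteArrayOutputStream()
      val oos = new ObjectOutputStream(bos)
      oos.writeObject(v); oos.close()
      bos.toByteArray
    }

    case class GENERIC(dataType: DataType) {
      def getField(row: InternalRow, ordinal: Int): Array[Byte] =
        serialize(row.get(ordinal, dataType))  // was: row.get(ordinal)
    }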
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
index 4eaec6d853..b1ef9b2ef7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
@@ -46,7 +46,7 @@ private[sql] trait Decoder[T <: AtomicType] {
 private[sql] trait CompressionScheme {
   def typeId: Int
 
-  def supports(columnType: ColumnType[_, _]): Boolean
+  def supports(columnType: ColumnType[_]): Boolean
 
   def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
index 6150df6930..c91d960a09 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
@@ -32,7 +32,7 @@ import org.apache.spark.util.Utils
 private[sql] case object PassThrough extends CompressionScheme {
   override val typeId = 0
 
-  override def supports(columnType: ColumnType[_, _]): Boolean = true
+  override def supports(columnType: ColumnType[_]): Boolean = true
 
   override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] = {
     new this.Encoder[T](columnType)
@@ -78,7 +78,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme {
     new this.Decoder(buffer, columnType)
   }
 
-  override def supports(columnType: ColumnType[_, _]): Boolean = columnType match {
+  override def supports(columnType: ColumnType[_]): Boolean = columnType match {
     case INT | LONG | SHORT | BYTE | STRING | BOOLEAN => true
     case _ => false
   }
@@ -128,7 +128,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme {
       while (from.hasRemaining) {
         columnType.extract(from, value, 0)
 
-        if (value.get(0) == currentValue.get(0)) {
+        if (value.get(0, columnType.dataType) == currentValue.get(0, columnType.dataType)) {
           currentRun += 1
         } else {
           // Writes current run
@@ -189,7 +189,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme {
     new this.Encoder[T](columnType)
   }
 
-  override def supports(columnType: ColumnType[_, _]): Boolean = columnType match {
+  override def supports(columnType: ColumnType[_]): Boolean = columnType match {
     case INT | LONG | STRING => true
     case _ => false
   }
@@ -304,7 +304,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme {
     (new this.Encoder).asInstanceOf[compression.Encoder[T]]
   }
 
-  override def supports(columnType: ColumnType[_, _]): Boolean = columnType == BOOLEAN
+  override def supports(columnType: ColumnType[_]): Boolean = columnType == BOOLEAN
 
   class Encoder extends compression.Encoder[BooleanType.type] {
     private var _uncompressedSize = 0
@@ -392,7 +392,7 @@ private[sql] case object IntDelta extends CompressionScheme {
     (new Encoder).asInstanceOf[compression.Encoder[T]]
   }
 
-  override def supports(columnType: ColumnType[_, _]): Boolean = columnType == INT
+  override def supports(columnType: ColumnType[_]): Boolean = columnType == INT
 
   class Encoder extends compression.Encoder[IntegerType.type] {
     protected var _compressedSize: Int = 0
@@ -472,7 +472,7 @@ private[sql] case object LongDelta extends CompressionScheme {
     (new Encoder).asInstanceOf[compression.Encoder[T]]
   }
 
-  override def supports(columnType: ColumnType[_, _]): Boolean = columnType == LONG
+  override def supports(columnType: ColumnType[_]): Boolean = columnType == LONG
 
   class Encoder extends compression.Encoder[LongType.type] {
     protected var _compressedSize: Int = 0
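The compression schemes mostly change only their signatures (`ColumnType[_, _]` becomes `ColumnType[_]`); the one behavioral touch point is run-length encoding, whose equality check between the current and previous value now reads both sides with the column's Catalyst type. A hedged sketch of that comparison, reusing the stub `InternalRow` and `ColumnType` from the sketches above:

    // RLE decides whether to extend the current run by comparing typed values.
    def sameValue(a: InternalRow, b: InternalRow, columnType: ColumnType[_]): Boolean =
      a.get(0, columnType.dataType) == b.get(0, columnType.dataType)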
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index 31e7b0e72e..4499a72070 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -58,15 +58,15 @@ class ColumnStatsSuite extends SparkFunSuite {
       val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
       rows.foreach(columnStats.gatherStats(_, 0))
 
-      val values = rows.take(10).map(_.get(0).asInstanceOf[T#InternalType])
+      val values = rows.take(10).map(_.get(0, columnType.dataType).asInstanceOf[T#InternalType])
       val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]]
       val stats = columnStats.collectedStatistics
 
-      assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0))
-      assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1))
-      assertResult(10, "Wrong null count")(stats.get(2))
-      assertResult(20, "Wrong row count")(stats.get(3))
-      assertResult(stats.get(4), "Wrong size in bytes") {
+      assertResult(values.min(ordering), "Wrong lower bound")(stats.genericGet(0))
+      assertResult(values.max(ordering), "Wrong upper bound")(stats.genericGet(1))
+      assertResult(10, "Wrong null count")(stats.genericGet(2))
+      assertResult(20, "Wrong row count")(stats.genericGet(3))
+      assertResult(stats.genericGet(4), "Wrong size in bytes") {
         rows.map { row =>
           if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
         }.sum
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index 4d46a65705..8f024690ef 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -32,13 +32,15 @@ import org.apache.spark.unsafe.types.UTF8String
 
 class ColumnTypeSuite extends SparkFunSuite with Logging {
 
-  val DEFAULT_BUFFER_SIZE = 512
+  private val DEFAULT_BUFFER_SIZE = 512
+  private val MAP_GENERIC = GENERIC(MapType(IntegerType, StringType))
 
   test("defaultSize") {
     val checks = Map(
       BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, DATE -> 4,
       LONG -> 8, TIMESTAMP -> 8, FLOAT -> 4, DOUBLE -> 8,
-      STRING -> 8, BINARY -> 16, FIXED_DECIMAL(15, 10) -> 8, GENERIC -> 16)
+      STRING -> 8, BINARY -> 16, FIXED_DECIMAL(15, 10) -> 8,
+      MAP_GENERIC -> 16)
 
     checks.foreach { case (columnType, expectedSize) =>
       assertResult(expectedSize, s"Wrong defaultSize for $columnType") {
@@ -48,8 +50,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
   }
 
   test("actualSize") {
-    def checkActualSize[T <: DataType, JvmType](
-        columnType: ColumnType[T, JvmType],
+    def checkActualSize[JvmType](
+        columnType: ColumnType[JvmType],
         value: JvmType,
         expected: Int): Unit = {
 
@@ -74,7 +76,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
     checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
 
     val generic = Map(1 -> "a")
-    checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8)
+    checkActualSize(MAP_GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8)
   }
 
   testNativeColumnType(BOOLEAN)(
@@ -123,7 +125,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
       UTF8String.fromBytes(bytes)
     })
 
-  testColumnType[BinaryType.type, Array[Byte]](
+  testColumnType[Array[Byte]](
     BINARY,
     (buffer: ByteBuffer, bytes: Array[Byte]) => {
       buffer.putInt(bytes.length).put(bytes)
@@ -140,7 +142,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
     val obj = Map(1 -> "spark", 2 -> "sql")
     val serializedObj = SparkSqlSerializer.serialize(obj)
 
-    GENERIC.append(SparkSqlSerializer.serialize(obj), buffer)
+    MAP_GENERIC.append(SparkSqlSerializer.serialize(obj), buffer)
     buffer.rewind()
 
     val length = buffer.getInt()
@@ -157,7 +159,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
     assertResult(obj, "Deserialized object didn't equal to the original object") {
       buffer.rewind()
-      SparkSqlSerializer.deserialize(GENERIC.extract(buffer))
+      SparkSqlSerializer.deserialize(MAP_GENERIC.extract(buffer))
     }
   }
 
@@ -170,7 +172,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
     val obj = CustomClass(Int.MaxValue, Long.MaxValue)
     val serializedObj = serializer.serialize(obj).array()
 
-    GENERIC.append(serializer.serialize(obj).array(), buffer)
+    MAP_GENERIC.append(serializer.serialize(obj).array(), buffer)
     buffer.rewind()
 
     val length = buffer.getInt
@@ -192,7 +194,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
     assertResult(obj, "Custom deserialized object didn't equal the original object") {
       buffer.rewind()
-      serializer.deserialize(ByteBuffer.wrap(GENERIC.extract(buffer)))
+      serializer.deserialize(ByteBuffer.wrap(MAP_GENERIC.extract(buffer)))
     }
   }
 
@@ -201,11 +203,11 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
       (putter: (ByteBuffer, T#InternalType) => Unit,
       getter: (ByteBuffer) => T#InternalType): Unit = {
 
-    testColumnType[T, T#InternalType](columnType, putter, getter)
+    testColumnType[T#InternalType](columnType, putter, getter)
   }
 
-  def testColumnType[T <: DataType, JvmType](
-      columnType: ColumnType[T, JvmType],
+  def testColumnType[JvmType](
+      columnType: ColumnType[JvmType],
       putter: (ByteBuffer, JvmType) => Unit,
       getter: (ByteBuffer) => JvmType): Unit = {
 
@@ -262,7 +264,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
       }
     }
 
-    assertResult(GENERIC) {
+    assertResult(GENERIC(DecimalType(19, 0))) {
       ColumnType(DecimalType(19, 0))
     }
   }
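A side effect visible in the test changes: because `GENERIC` is now a case class, instances compare structurally, which is what lets `ColumnTypeSuite` assert that `ColumnType(DecimalType(19, 0))` equals a freshly constructed `GENERIC(DecimalType(19, 0))`. For example (stub `DecimalType`, reusing the `GENERIC` sketch above):

    case class DecimalType(precision: Int, scale: Int) extends DataType

    // Structural equality: two GENERIC instances with the same DataType are equal.
    assert(GENERIC(DecimalType(19, 0)) == GENERIC(DecimalType(19, 0)))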
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
index d986133973..79bb7d072f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
@@ -31,7 +31,7 @@ object ColumnarTestUtils {
     row
   }
 
-  def makeRandomValue[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]): JvmType = {
+  def makeRandomValue[JvmType](columnType: ColumnType[JvmType]): JvmType = {
     def randomBytes(length: Int) = {
       val bytes = new Array[Byte](length)
       Random.nextBytes(bytes)
@@ -58,15 +58,15 @@ object ColumnarTestUtils {
   }
 
   def makeRandomValues(
-      head: ColumnType[_ <: DataType, _],
-      tail: ColumnType[_ <: DataType, _]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail)
+      head: ColumnType[_],
+      tail: ColumnType[_]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail)
 
-  def makeRandomValues(columnTypes: Seq[ColumnType[_ <: DataType, _]]): Seq[Any] = {
+  def makeRandomValues(columnTypes: Seq[ColumnType[_]]): Seq[Any] = {
     columnTypes.map(makeRandomValue(_))
   }
 
-  def makeUniqueRandomValues[T <: DataType, JvmType](
-      columnType: ColumnType[T, JvmType],
+  def makeUniqueRandomValues[JvmType](
+      columnType: ColumnType[JvmType],
       count: Int): Seq[JvmType] = {
 
     Iterator.iterate(HashSet.empty[JvmType]) { set =>
@@ -75,10 +75,10 @@ object ColumnarTestUtils {
   }
 
   def makeRandomRow(
-      head: ColumnType[_ <: DataType, _],
-      tail: ColumnType[_ <: DataType, _]*): InternalRow = makeRandomRow(Seq(head) ++ tail)
+      head: ColumnType[_],
+      tail: ColumnType[_]*): InternalRow = makeRandomRow(Seq(head) ++ tail)
 
-  def makeRandomRow(columnTypes: Seq[ColumnType[_ <: DataType, _]]): InternalRow = {
+  def makeRandomRow(columnTypes: Seq[ColumnType[_]]): InternalRow = {
     val row = new GenericMutableRow(columnTypes.length)
     makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) =>
       row(index) = value
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
index d421f4d8d0..f4f6c7649b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
@@ -21,17 +21,17 @@ import java.nio.ByteBuffer
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{StringType, ArrayType, DataType}
 
-class TestNullableColumnAccessor[T <: DataType, JvmType](
+class TestNullableColumnAccessor[JvmType](
     buffer: ByteBuffer,
-    columnType: ColumnType[T, JvmType])
+    columnType: ColumnType[JvmType])
   extends BasicColumnAccessor(buffer, columnType)
   with NullableColumnAccessor
 
 object TestNullableColumnAccessor {
-  def apply[T <: DataType, JvmType](buffer: ByteBuffer, columnType: ColumnType[T, JvmType])
-    : TestNullableColumnAccessor[T, JvmType] = {
+  def apply[JvmType](buffer: ByteBuffer, columnType: ColumnType[JvmType])
+    : TestNullableColumnAccessor[JvmType] = {
     // Skips the column type ID
     buffer.getInt()
     new TestNullableColumnAccessor(buffer, columnType)
@@ -43,13 +43,13 @@ class NullableColumnAccessorSuite extends SparkFunSuite {
 
   Seq(
     BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE,
-    STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC)
+    STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType)))
     .foreach {
       testNullableColumnAccessor(_)
     }
 
-  def testNullableColumnAccessor[T <: DataType, JvmType](
-      columnType: ColumnType[T, JvmType]): Unit = {
+  def testNullableColumnAccessor[JvmType](
+      columnType: ColumnType[JvmType]): Unit = {
     val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
     val nullRow = makeNullRow(1)
 
@@ -75,7 +75,7 @@ class NullableColumnAccessorSuite extends SparkFunSuite {
     (0 until 4).foreach { _ =>
       assert(accessor.hasNext)
       accessor.extractTo(row, 0)
-      assert(row.get(0) === randomRow.get(0))
+      assert(row.get(0, columnType.dataType) === randomRow.get(0, columnType.dataType))
 
       assert(accessor.hasNext)
       accessor.extractTo(row, 0)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
index cd8bf75ff1..241d09ea20 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
@@ -21,13 +21,13 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.execution.SparkSqlSerializer
 import org.apache.spark.sql.types._
 
-class TestNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType])
-  extends BasicColumnBuilder[T, JvmType](new NoopColumnStats, columnType)
+class TestNullableColumnBuilder[JvmType](columnType: ColumnType[JvmType])
+  extends BasicColumnBuilder[JvmType](new NoopColumnStats, columnType)
   with NullableColumnBuilder
 
 object TestNullableColumnBuilder {
-  def apply[T <: DataType, JvmType](columnType: ColumnType[T, JvmType], initialSize: Int = 0)
-    : TestNullableColumnBuilder[T, JvmType] = {
+  def apply[JvmType](columnType: ColumnType[JvmType], initialSize: Int = 0)
+    : TestNullableColumnBuilder[JvmType] = {
     val builder = new TestNullableColumnBuilder(columnType)
     builder.initialize(initialSize)
     builder
@@ -39,13 +39,13 @@ class NullableColumnBuilderSuite extends SparkFunSuite {
 
   Seq(
     BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE,
-    STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC)
+    STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType)))
     .foreach {
      testNullableColumnBuilder(_)
    }
 
-  def testNullableColumnBuilder[T <: DataType, JvmType](
-      columnType: ColumnType[T, JvmType]): Unit = {
+  def testNullableColumnBuilder[JvmType](
+      columnType: ColumnType[JvmType]): Unit = {
     val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
 
@@ -92,13 +92,14 @@ class NullableColumnBuilderSuite extends SparkFunSuite {
 
     // For non-null values
     (0 until 4).foreach { _ =>
-      val actual = if (columnType == GENERIC) {
-        SparkSqlSerializer.deserialize[Any](GENERIC.extract(buffer))
+      val actual = if (columnType.isInstanceOf[GENERIC]) {
+        SparkSqlSerializer.deserialize[Any](columnType.extract(buffer).asInstanceOf[Array[Byte]])
       } else {
        columnType.extract(buffer)
      }
 
-      assert(actual === randomRow.get(0), "Extracted value didn't equal to the original one")
+      assert(actual === randomRow.get(0, columnType.dataType),
+        "Extracted value didn't equal to the original one")
     }
 
     assert(!buffer.hasRemaining)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
index 33092c83a1..9a2948c59b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
@@ -33,7 +33,7 @@ class BooleanBitSetSuite extends SparkFunSuite {
     val builder = TestCompressibleColumnBuilder(new NoopColumnStats, BOOLEAN, BooleanBitSet)
     val rows = Seq.fill[InternalRow](count)(makeRandomRow(BOOLEAN))
-    val values = rows.map(_.get(0))
+    val values = rows.map(_.getBoolean(0))
 
     rows.foreach(builder.appendFrom(_, 0))
     val buffer = builder.build()
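Taken together, a generic column now carries its Catalyst type from builder to stats to accessor. A runnable toy model of the `GENERIC` append/extract round trip that the tests above exercise (length-prefixed payload in a `ByteBuffer`; Java serialization as a stand-in for `SparkSqlSerializer`):

    import java.io._
    import java.nio.ByteBuffer

    def serialize(v: Any): Array[Byte] = {
      val bos = new ByteArrayOutputStream()
      val oos = new ObjectOutputStream(bos)
      oos.writeObject(v); oos.close()
      bos.toByteArray
    }
    def deserialize(bytes: Array[Byte]): Any =
      new ObjectInputStream(new ByteArrayInputStream(bytes)).readObject()

    val value  = Map(1 -> "spark", 2 -> "sql")
    val bytes  = serialize(value)

    val buffer = ByteBuffer.allocate(4 + bytes.length)
    buffer.putInt(bytes.length).put(bytes)   // GENERIC.append: length prefix, then payload
    buffer.rewind()

    val out = new Array[Byte](buffer.getInt) // GENERIC.extract: read length, then bytes
    buffer.get(out)
    assert(deserialize(out) == value)        // round trip preserves the value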