author      Davies Liu <davies@databricks.com>      2015-10-07 15:58:07 -0700
committer   Davies Liu <davies.liu@gmail.com>       2015-10-07 15:58:07 -0700
commit      075a0b658289608c8732e07e26e14d736e673ce9 (patch)
tree        91ab61c1f6cf7d9284c00f4e35037da7721c812a
parent      dd36ec6bc5844aaa045a4bd9ba49113528e1740c (diff)
download    spark-075a0b658289608c8732e07e26e14d736e673ce9.tar.gz
            spark-075a0b658289608c8732e07e26e14d736e673ce9.tar.bz2
            spark-075a0b658289608c8732e07e26e14d736e673ce9.zip
[SPARK-10917] [SQL] improve performance of complex types in columnar cache
This PR improves the performance of complex types in the columnar cache by using UnsafeProjection instead of KryoSerializer. A simple benchmark shows that this change speeds up scanning a cached table with complex columns by about 15x compared to Spark 1.5. Here is the code used for the benchmark:

```
df = sc.range(1<<23).map(lambda i: Row(a=Row(b=i, c=str(i)), d=range(10), e=dict(zip(range(10), [str(i) for i in range(10)])))).toDF()
df.write.parquet("table")
```

```
df = sqlContext.read.parquet("table")
df.cache()
df.count()
t = time.time()
print df.select("*")._jdf.queryExecution().toRdd().count()
print time.time() - t
```

Author: Davies Liu <davies@databricks.com>

Closes #8971 from davies/complex.
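The core of the change, visible in ColumnType.scala below, is that STRUCT, ARRAY, and MAP columns are now encoded with UnsafeProjection and decoded by pointing an UnsafeRow (or UnsafeArrayData/UnsafeMapData) at the stored bytes, instead of round-tripping objects through Kryo. The snippet below is a minimal sketch of that idea for the struct case, using only the Catalyst APIs that appear in this patch; the schema and object name are illustrative and not part of the change.

```
// Minimal sketch (not from the patch): encode/decode a struct value the way the
// new STRUCT column type does, via UnsafeProjection instead of Kryo.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types._

object StructCodecSketch {
  // Illustrative schema, matching Row(b=i, c=str(i)) from the benchmark above.
  val schema: StructType =
    StructType(StructField("b", LongType) :: StructField("c", StringType) :: Nil)

  private val projection = UnsafeProjection.create(schema)

  // Serialize: project the row into an UnsafeRow and take its backing bytes.
  def serialize(row: InternalRow): Array[Byte] = projection(row).getBytes

  // Deserialize: point a fresh UnsafeRow at the bytes; nothing is copied and no
  // object graph is rebuilt, which is where the speedup over Kryo comes from.
  def deserialize(bytes: Array[Byte]): InternalRow = {
    val row = new UnsafeRow
    row.pointTo(bytes, schema.fields.length, bytes.length)
    row
  }
}
```

ARRAY and MAP follow the same pattern, prefixing the serialized bytes with the element count (and, for maps, the size of the key array) so the stored buffer can be re-pointed on read without copying.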
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java    1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala               5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala                  38
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala                   40
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala                     11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala                     219
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala                 9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala                235
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala               18
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala       8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala     13
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala      21
12 files changed, 352 insertions(+), 266 deletions(-)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 6a16d34083..fdd9125613 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -23,7 +23,6 @@ import java.math.BigInteger;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
-import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala
index f6fa021ade..52069598ee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala
@@ -48,6 +48,11 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) exte
}
object ArrayBasedMapData {
+ def apply(map: Map[Any, Any]): ArrayBasedMapData = {
+ val array = map.toArray
+ ArrayBasedMapData(array.map(_._1), array.map(_._2))
+ }
+
def apply(keys: Array[Any], values: Array[Any]): ArrayBasedMapData = {
new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
index 2b1d700987..f04099f54c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.columnar
import java.nio.{ByteBuffer, ByteOrder}
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.MutableRow
import org.apache.spark.sql.columnar.compression.CompressibleColumnAccessor
import org.apache.spark.sql.types._
@@ -61,6 +62,10 @@ private[sql] abstract class BasicColumnAccessor[JvmType](
protected def underlyingBuffer = buffer
}
+private[sql] class NullColumnAccess(buffer: ByteBuffer)
+ extends BasicColumnAccessor[Any](buffer, NULL)
+ with NullableColumnAccessor
+
private[sql] abstract class NativeColumnAccessor[T <: AtomicType](
override protected val buffer: ByteBuffer,
override protected val columnType: NativeColumnType[T])
@@ -96,11 +101,23 @@ private[sql] class BinaryColumnAccessor(buffer: ByteBuffer)
extends BasicColumnAccessor[Array[Byte]](buffer, BINARY)
with NullableColumnAccessor
-private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: Int, scale: Int)
- extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale))
+private[sql] class CompactDecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType)
+ extends NativeColumnAccessor(buffer, COMPACT_DECIMAL(dataType))
+
+private[sql] class DecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType)
+ extends BasicColumnAccessor[Decimal](buffer, LARGE_DECIMAL(dataType))
+ with NullableColumnAccessor
+
+private[sql] class StructColumnAccessor(buffer: ByteBuffer, dataType: StructType)
+ extends BasicColumnAccessor[InternalRow](buffer, STRUCT(dataType))
+ with NullableColumnAccessor
+
+private[sql] class ArrayColumnAccessor(buffer: ByteBuffer, dataType: ArrayType)
+ extends BasicColumnAccessor[ArrayData](buffer, ARRAY(dataType))
+ with NullableColumnAccessor
-private[sql] class GenericColumnAccessor(buffer: ByteBuffer, dataType: DataType)
- extends BasicColumnAccessor[Array[Byte]](buffer, GENERIC(dataType))
+private[sql] class MapColumnAccessor(buffer: ByteBuffer, dataType: MapType)
+ extends BasicColumnAccessor[MapData](buffer, MAP(dataType))
with NullableColumnAccessor
private[sql] object ColumnAccessor {
@@ -108,6 +125,7 @@ private[sql] object ColumnAccessor {
val buf = buffer.order(ByteOrder.nativeOrder)
dataType match {
+ case NullType => new NullColumnAccess(buf)
case BooleanType => new BooleanColumnAccessor(buf)
case ByteType => new ByteColumnAccessor(buf)
case ShortType => new ShortColumnAccessor(buf)
@@ -117,9 +135,15 @@ private[sql] object ColumnAccessor {
case DoubleType => new DoubleColumnAccessor(buf)
case StringType => new StringColumnAccessor(buf)
case BinaryType => new BinaryColumnAccessor(buf)
- case DecimalType.Fixed(precision, scale) if precision < 19 =>
- new FixedDecimalColumnAccessor(buf, precision, scale)
- case other => new GenericColumnAccessor(buf, other)
+ case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
+ new CompactDecimalColumnAccessor(buf, dt)
+ case dt: DecimalType => new DecimalColumnAccessor(buf, dt)
+ case struct: StructType => new StructColumnAccessor(buf, struct)
+ case array: ArrayType => new ArrayColumnAccessor(buf, array)
+ case map: MapType => new MapColumnAccessor(buf, map)
+ case udt: UserDefinedType[_] => ColumnAccessor(udt.sqlType, buffer)
+ case other =>
+ throw new Exception(s"not support type: $other")
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
index 2e60564f7c..7a7345a7e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
@@ -77,6 +77,10 @@ private[sql] class BasicColumnBuilder[JvmType](
}
}
+private[sql] class NullColumnBuilder
+ extends BasicColumnBuilder[Any](new ObjectColumnStats(NullType), NULL)
+ with NullableColumnBuilder
+
private[sql] abstract class ComplexColumnBuilder[JvmType](
columnStats: ColumnStats,
columnType: ColumnType[JvmType])
@@ -109,16 +113,20 @@ private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringCol
private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY)
-private[sql] class FixedDecimalColumnBuilder(
- precision: Int,
- scale: Int)
- extends NativeColumnBuilder(
- new FixedDecimalColumnStats(precision, scale),
- FIXED_DECIMAL(precision, scale))
+private[sql] class CompactDecimalColumnBuilder(dataType: DecimalType)
+ extends NativeColumnBuilder(new DecimalColumnStats(dataType), COMPACT_DECIMAL(dataType))
+
+private[sql] class DecimalColumnBuilder(dataType: DecimalType)
+ extends ComplexColumnBuilder(new DecimalColumnStats(dataType), LARGE_DECIMAL(dataType))
+
+private[sql] class StructColumnBuilder(dataType: StructType)
+ extends ComplexColumnBuilder(new ObjectColumnStats(dataType), STRUCT(dataType))
+
+private[sql] class ArrayColumnBuilder(dataType: ArrayType)
+ extends ComplexColumnBuilder(new ObjectColumnStats(dataType), ARRAY(dataType))
-// TODO (lian) Add support for array, struct and map
-private[sql] class GenericColumnBuilder(dataType: DataType)
- extends ComplexColumnBuilder(new GenericColumnStats(dataType), GENERIC(dataType))
+private[sql] class MapColumnBuilder(dataType: MapType)
+ extends ComplexColumnBuilder(new ObjectColumnStats(dataType), MAP(dataType))
private[sql] object ColumnBuilder {
val DEFAULT_INITIAL_BUFFER_SIZE = 1024 * 1024
@@ -145,6 +153,7 @@ private[sql] object ColumnBuilder {
columnName: String = "",
useCompression: Boolean = false): ColumnBuilder = {
val builder: ColumnBuilder = dataType match {
+ case NullType => new NullColumnBuilder
case BooleanType => new BooleanColumnBuilder
case ByteType => new ByteColumnBuilder
case ShortType => new ShortColumnBuilder
@@ -154,9 +163,16 @@ private[sql] object ColumnBuilder {
case DoubleType => new DoubleColumnBuilder
case StringType => new StringColumnBuilder
case BinaryType => new BinaryColumnBuilder
- case DecimalType.Fixed(precision, scale) if precision < 19 =>
- new FixedDecimalColumnBuilder(precision, scale)
- case other => new GenericColumnBuilder(other)
+ case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
+ new CompactDecimalColumnBuilder(dt)
+ case dt: DecimalType => new DecimalColumnBuilder(dt)
+ case struct: StructType => new StructColumnBuilder(struct)
+ case array: ArrayType => new ArrayColumnBuilder(array)
+ case map: MapType => new MapColumnBuilder(map)
+ case udt: UserDefinedType[_] =>
+ return apply(udt.sqlType, initialSize, columnName, useCompression)
+ case other =>
+ throw new Exception(s"not suppported type: $other")
}
builder.initialize(initialSize, columnName, useCompression)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
index 3b5052b754..ba61003ba4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -235,7 +235,9 @@ private[sql] class BinaryColumnStats extends ColumnStats {
new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes))
}
-private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends ColumnStats {
+private[sql] class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats {
+ def this(dt: DecimalType) = this(dt.precision, dt.scale)
+
protected var upper: Decimal = null
protected var lower: Decimal = null
@@ -245,7 +247,8 @@ private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends C
val value = row.getDecimal(ordinal, precision, scale)
if (upper == null || value.compareTo(upper) > 0) upper = value
if (lower == null || value.compareTo(lower) < 0) lower = value
- sizeInBytes += FIXED_DECIMAL.defaultSize
+ // TODO: this is not right for DecimalType with precision > 18
+ sizeInBytes += 8
}
}
@@ -253,8 +256,8 @@ private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends C
new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
}
-private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats {
- val columnType = GENERIC(dataType)
+private[sql] class ObjectColumnStats(dataType: DataType) extends ColumnStats {
+ val columnType = ColumnType(dataType)
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index 3a0cea8750..3563eacb3a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -17,14 +17,15 @@
package org.apache.spark.sql.columnar
-import java.nio.ByteBuffer
+import java.math.{BigDecimal, BigInteger}
+import java.nio.{ByteOrder, ByteBuffer}
import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.MutableRow
-import org.apache.spark.sql.execution.SparkSqlSerializer
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.types.UTF8String
/**
@@ -102,6 +103,16 @@ private[sql] sealed abstract class ColumnType[JvmType] {
override def toString: String = getClass.getSimpleName.stripSuffix("$")
}
+private[sql] object NULL extends ColumnType[Any] {
+
+ override def dataType: DataType = NullType
+ override def defaultSize: Int = 0
+ override def append(v: Any, buffer: ByteBuffer): Unit = {}
+ override def extract(buffer: ByteBuffer): Any = null
+ override def setField(row: MutableRow, ordinal: Int, value: Any): Unit = row.setNullAt(ordinal)
+ override def getField(row: InternalRow, ordinal: Int): Any = null
+}
+
private[sql] abstract class NativeColumnType[T <: AtomicType](
val dataType: T,
val defaultSize: Int)
@@ -339,10 +350,8 @@ private[sql] object STRING extends NativeColumnType(StringType, 8) {
override def clone(v: UTF8String): UTF8String = v.clone()
}
-private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int)
- extends NativeColumnType(
- DecimalType(precision, scale),
- FIXED_DECIMAL.defaultSize) {
+private[sql] case class COMPACT_DECIMAL(precision: Int, scale: Int)
+ extends NativeColumnType(DecimalType(precision, scale), 8) {
override def extract(buffer: ByteBuffer): Decimal = {
Decimal(buffer.getLong(), precision, scale)
@@ -365,32 +374,39 @@ private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int)
}
}
-private[sql] object FIXED_DECIMAL {
- val defaultSize = 8
+private[sql] object COMPACT_DECIMAL {
+ def apply(dt: DecimalType): COMPACT_DECIMAL = {
+ COMPACT_DECIMAL(dt.precision, dt.scale)
+ }
}
-private[sql] sealed abstract class ByteArrayColumnType(val defaultSize: Int)
- extends ColumnType[Array[Byte]] {
+private[sql] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: Int)
+ extends ColumnType[JvmType] {
+
+ def serialize(value: JvmType): Array[Byte]
+ def deserialize(bytes: Array[Byte]): JvmType
override def actualSize(row: InternalRow, ordinal: Int): Int = {
- getField(row, ordinal).length + 4
+ // TODO: grow the buffer in append(), so serialize() will not be called twice
+ serialize(getField(row, ordinal)).length + 4
}
- override def append(v: Array[Byte], buffer: ByteBuffer): Unit = {
- buffer.putInt(v.length).put(v, 0, v.length)
+ override def append(v: JvmType, buffer: ByteBuffer): Unit = {
+ val bytes = serialize(v)
+ buffer.putInt(bytes.length).put(bytes, 0, bytes.length)
}
- override def extract(buffer: ByteBuffer): Array[Byte] = {
+ override def extract(buffer: ByteBuffer): JvmType = {
val length = buffer.getInt()
val bytes = new Array[Byte](length)
buffer.get(bytes, 0, length)
- bytes
+ deserialize(bytes)
}
}
-private[sql] object BINARY extends ByteArrayColumnType(16) {
+private[sql] object BINARY extends ByteArrayColumnType[Array[Byte]](16) {
- def dataType: DataType = BooleanType
+ def dataType: DataType = BinaryType
override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = {
row.update(ordinal, value)
@@ -399,24 +415,164 @@ private[sql] object BINARY extends ByteArrayColumnType(16) {
override def getField(row: InternalRow, ordinal: Int): Array[Byte] = {
row.getBinary(ordinal)
}
+
+ def serialize(value: Array[Byte]): Array[Byte] = value
+ def deserialize(bytes: Array[Byte]): Array[Byte] = bytes
}
-// Used to process generic objects (all types other than those listed above). Objects should be
-// serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized
-// byte array.
-private[sql] case class GENERIC(dataType: DataType) extends ByteArrayColumnType(16) {
- override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = {
- row.update(ordinal, SparkSqlSerializer.deserialize[Any](value))
+private[sql] case class LARGE_DECIMAL(precision: Int, scale: Int)
+ extends ByteArrayColumnType[Decimal](12) {
+
+ override val dataType: DataType = DecimalType(precision, scale)
+
+ override def getField(row: InternalRow, ordinal: Int): Decimal = {
+ row.getDecimal(ordinal, precision, scale)
}
- override def getField(row: InternalRow, ordinal: Int): Array[Byte] = {
- SparkSqlSerializer.serialize(row.get(ordinal, dataType))
+ override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = {
+ row.setDecimal(ordinal, value, precision)
+ }
+
+ override def serialize(value: Decimal): Array[Byte] = {
+ value.toJavaBigDecimal.unscaledValue().toByteArray
+ }
+
+ override def deserialize(bytes: Array[Byte]): Decimal = {
+ val javaDecimal = new BigDecimal(new BigInteger(bytes), scale)
+ Decimal.apply(javaDecimal, precision, scale)
+ }
+}
+
+private[sql] object LARGE_DECIMAL {
+ def apply(dt: DecimalType): LARGE_DECIMAL = {
+ LARGE_DECIMAL(dt.precision, dt.scale)
+ }
+}
+
+private[sql] case class STRUCT(dataType: StructType)
+ extends ByteArrayColumnType[InternalRow](20) {
+
+ private val projection: UnsafeProjection =
+ UnsafeProjection.create(dataType)
+ private val numOfFields: Int = dataType.fields.size
+
+ override def setField(row: MutableRow, ordinal: Int, value: InternalRow): Unit = {
+ row.update(ordinal, value)
+ }
+
+ override def getField(row: InternalRow, ordinal: Int): InternalRow = {
+ row.getStruct(ordinal, numOfFields)
+ }
+
+ override def serialize(value: InternalRow): Array[Byte] = {
+ val unsafeRow = if (value.isInstanceOf[UnsafeRow]) {
+ value.asInstanceOf[UnsafeRow]
+ } else {
+ projection(value)
+ }
+ unsafeRow.getBytes
+ }
+
+ override def deserialize(bytes: Array[Byte]): InternalRow = {
+ val unsafeRow = new UnsafeRow
+ unsafeRow.pointTo(bytes, numOfFields, bytes.length)
+ unsafeRow
+ }
+
+ override def clone(v: InternalRow): InternalRow = v.copy()
+}
+
+private[sql] case class ARRAY(dataType: ArrayType)
+ extends ByteArrayColumnType[ArrayData](16) {
+
+ private lazy val projection = UnsafeProjection.create(Array[DataType](dataType))
+ private val mutableRow = new GenericMutableRow(new Array[Any](1))
+
+ override def setField(row: MutableRow, ordinal: Int, value: ArrayData): Unit = {
+ row.update(ordinal, value)
+ }
+
+ override def getField(row: InternalRow, ordinal: Int): ArrayData = {
+ row.getArray(ordinal)
+ }
+
+ override def serialize(value: ArrayData): Array[Byte] = {
+ val unsafeArray = if (value.isInstanceOf[UnsafeArrayData]) {
+ value.asInstanceOf[UnsafeArrayData]
+ } else {
+ mutableRow(0) = value
+ projection(mutableRow).getArray(0)
+ }
+ val outputBuffer =
+ ByteBuffer.allocate(4 + unsafeArray.getSizeInBytes).order(ByteOrder.nativeOrder())
+ outputBuffer.putInt(unsafeArray.numElements())
+ val underlying = outputBuffer.array()
+ unsafeArray.writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 4)
+ underlying
+ }
+
+ override def deserialize(bytes: Array[Byte]): ArrayData = {
+ val buffer = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder())
+ val numElements = buffer.getInt
+ val array = new UnsafeArrayData
+ array.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + 4, numElements, bytes.length - 4)
+ array
}
+
+ override def clone(v: ArrayData): ArrayData = v.copy()
+}
+
+private[sql] case class MAP(dataType: MapType) extends ByteArrayColumnType[MapData](32) {
+
+ private lazy val projection: UnsafeProjection = UnsafeProjection.create(Array[DataType](dataType))
+ private val mutableRow = new GenericMutableRow(new Array[Any](1))
+
+ override def setField(row: MutableRow, ordinal: Int, value: MapData): Unit = {
+ row.update(ordinal, value)
+ }
+
+ override def getField(row: InternalRow, ordinal: Int): MapData = {
+ row.getMap(ordinal)
+ }
+
+ override def serialize(value: MapData): Array[Byte] = {
+ val unsafeMap = if (value.isInstanceOf[UnsafeMapData]) {
+ value.asInstanceOf[UnsafeMapData]
+ } else {
+ mutableRow(0) = value
+ projection(mutableRow).getMap(0)
+ }
+
+ val outputBuffer =
+ ByteBuffer.allocate(8 + unsafeMap.getSizeInBytes).order(ByteOrder.nativeOrder())
+ outputBuffer.putInt(unsafeMap.numElements())
+ val keyBytes = unsafeMap.keyArray().getSizeInBytes
+ outputBuffer.putInt(keyBytes)
+ val underlying = outputBuffer.array()
+ unsafeMap.keyArray().writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 8)
+ unsafeMap.valueArray().writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 8 + keyBytes)
+ underlying
+ }
+
+ override def deserialize(bytes: Array[Byte]): MapData = {
+ val buffer = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder())
+ val numElements = buffer.getInt
+ val keyArraySize = buffer.getInt
+ val keyArray = new UnsafeArrayData
+ val valueArray = new UnsafeArrayData
+ keyArray.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + 8, numElements, keyArraySize)
+ valueArray.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + 8 + keyArraySize, numElements,
+ bytes.length - 8 - keyArraySize)
+ new UnsafeMapData(keyArray, valueArray)
+ }
+
+ override def clone(v: MapData): MapData = v.copy()
}
private[sql] object ColumnType {
def apply(dataType: DataType): ColumnType[_] = {
dataType match {
+ case NullType => NULL
case BooleanType => BOOLEAN
case ByteType => BYTE
case ShortType => SHORT
@@ -426,9 +582,14 @@ private[sql] object ColumnType {
case DoubleType => DOUBLE
case StringType => STRING
case BinaryType => BINARY
- case DecimalType.Fixed(precision, scale) if precision < 19 =>
- FIXED_DECIMAL(precision, scale)
- case other => GENERIC(other)
+ case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => COMPACT_DECIMAL(dt)
+ case dt: DecimalType => LARGE_DECIMAL(dt)
+ case arr: ArrayType => ARRAY(arr)
+ case map: MapType => MAP(map)
+ case struct: StructType => STRUCT(struct)
+ case udt: UserDefinedType[_] => apply(udt.sqlType)
+ case other =>
+ throw new Exception(s"Unsupported type: $other")
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index 708fb4cf79..89a664001b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.columnar
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types._
@@ -76,11 +75,11 @@ class ColumnStatsSuite extends SparkFunSuite {
def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats](
initialStatistics: GenericInternalRow): Unit = {
- val columnStatsName = classOf[FixedDecimalColumnStats].getSimpleName
- val columnType = FIXED_DECIMAL(15, 10)
+ val columnStatsName = classOf[DecimalColumnStats].getSimpleName
+ val columnType = COMPACT_DECIMAL(15, 10)
test(s"$columnStatsName: empty") {
- val columnStats = new FixedDecimalColumnStats(15, 10)
+ val columnStats = new DecimalColumnStats(15, 10)
columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach {
case (actual, expected) => assert(actual === expected)
}
@@ -89,7 +88,7 @@ class ColumnStatsSuite extends SparkFunSuite {
test(s"$columnStatsName: non-empty") {
import org.apache.spark.sql.columnar.ColumnarTestUtils._
- val columnStats = new FixedDecimalColumnStats(15, 10)
+ val columnStats = new DecimalColumnStats(15, 10)
val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
rows.foreach(columnStats.gatherStats(_, 0))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index a4cbe3525e..ceb8ad97bb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -19,28 +19,25 @@ package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
-import com.esotericsoftware.kryo.io.{Input, Output}
-import com.esotericsoftware.kryo.{Kryo, Serializer}
-
-import org.apache.spark.{Logging, SparkConf, SparkFunSuite}
-import org.apache.spark.serializer.KryoRegistrator
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.columnar.ColumnarTestUtils._
-import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.{Logging, SparkFunSuite}
class ColumnTypeSuite extends SparkFunSuite with Logging {
private val DEFAULT_BUFFER_SIZE = 512
- private val MAP_GENERIC = GENERIC(MapType(IntegerType, StringType))
+ private val MAP_TYPE = MAP(MapType(IntegerType, StringType))
+ private val ARRAY_TYPE = ARRAY(ArrayType(IntegerType))
+ private val STRUCT_TYPE = STRUCT(StructType(StructField("a", StringType) :: Nil))
test("defaultSize") {
val checks = Map(
- BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4,
- LONG -> 8, FLOAT -> 4, DOUBLE -> 8,
- STRING -> 8, BINARY -> 16, FIXED_DECIMAL(15, 10) -> 8,
- MAP_GENERIC -> 16)
+ NULL -> 0, BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, LONG -> 8,
+ FLOAT -> 4, DOUBLE -> 8, COMPACT_DECIMAL(15, 10) -> 8, LARGE_DECIMAL(20, 10) -> 12,
+ STRING -> 8, BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 16, MAP_TYPE -> 32)
checks.foreach { case (columnType, expectedSize) =>
assertResult(expectedSize, s"Wrong defaultSize for $columnType") {
@@ -50,18 +47,19 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
}
test("actualSize") {
- def checkActualSize[JvmType](
- columnType: ColumnType[JvmType],
- value: JvmType,
+ def checkActualSize(
+ columnType: ColumnType[_],
+ value: Any,
expected: Int): Unit = {
assertResult(expected, s"Wrong actualSize for $columnType") {
val row = new GenericMutableRow(1)
- columnType.setField(row, 0, value)
+ row.update(0, CatalystTypeConverters.convertToCatalyst(value))
columnType.actualSize(row, 0)
}
}
+ checkActualSize(NULL, null, 0)
checkActualSize(BOOLEAN, true, 1)
checkActualSize(BYTE, Byte.MaxValue, 1)
checkActualSize(SHORT, Short.MaxValue, 2)
@@ -69,176 +67,65 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
checkActualSize(LONG, Long.MaxValue, 8)
checkActualSize(FLOAT, Float.MaxValue, 4)
checkActualSize(DOUBLE, Double.MaxValue, 8)
- checkActualSize(STRING, UTF8String.fromString("hello"), 4 + "hello".getBytes("utf-8").length)
+ checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length)
checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 4)
- checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
-
- val generic = Map(1 -> "a")
- checkActualSize(MAP_GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8)
+ checkActualSize(COMPACT_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
+ checkActualSize(LARGE_DECIMAL(20, 10), Decimal(0, 20, 10), 5)
+ checkActualSize(ARRAY_TYPE, Array[Any](1), 16)
+ checkActualSize(MAP_TYPE, Map(1 -> "a"), 25)
+ checkActualSize(STRUCT_TYPE, Row("hello"), 28)
}
- testNativeColumnType(BOOLEAN)(
- (buffer: ByteBuffer, v: Boolean) => {
- buffer.put((if (v) 1 else 0).toByte)
- },
- (buffer: ByteBuffer) => {
- buffer.get() == 1
- })
-
- testNativeColumnType(BYTE)(_.put(_), _.get)
-
- testNativeColumnType(SHORT)(_.putShort(_), _.getShort)
-
- testNativeColumnType(INT)(_.putInt(_), _.getInt)
-
- testNativeColumnType(LONG)(_.putLong(_), _.getLong)
-
- testNativeColumnType(FLOAT)(_.putFloat(_), _.getFloat)
-
- testNativeColumnType(DOUBLE)(_.putDouble(_), _.getDouble)
-
- testNativeColumnType(FIXED_DECIMAL(15, 10))(
- (buffer: ByteBuffer, decimal: Decimal) => {
- buffer.putLong(decimal.toUnscaledLong)
- },
- (buffer: ByteBuffer) => {
- Decimal(buffer.getLong(), 15, 10)
- })
-
-
- testNativeColumnType(STRING)(
- (buffer: ByteBuffer, string: UTF8String) => {
- val bytes = string.getBytes
- buffer.putInt(bytes.length)
- buffer.put(bytes)
- },
- (buffer: ByteBuffer) => {
- val length = buffer.getInt()
- val bytes = new Array[Byte](length)
- buffer.get(bytes)
- UTF8String.fromBytes(bytes)
- })
-
- testColumnType[Array[Byte]](
- BINARY,
- (buffer: ByteBuffer, bytes: Array[Byte]) => {
- buffer.putInt(bytes.length).put(bytes)
- },
- (buffer: ByteBuffer) => {
- val length = buffer.getInt()
- val bytes = new Array[Byte](length)
- buffer.get(bytes, 0, length)
- bytes
- })
-
- test("GENERIC") {
- val buffer = ByteBuffer.allocate(512)
- val obj = Map(1 -> "spark", 2 -> "sql")
- val serializedObj = SparkSqlSerializer.serialize(obj)
-
- MAP_GENERIC.append(SparkSqlSerializer.serialize(obj), buffer)
- buffer.rewind()
-
- val length = buffer.getInt()
- assert(length === serializedObj.length)
-
- assertResult(obj, "Deserialized object didn't equal to the original object") {
- val bytes = new Array[Byte](length)
- buffer.get(bytes, 0, length)
- SparkSqlSerializer.deserialize(bytes)
- }
-
- buffer.rewind()
- buffer.putInt(serializedObj.length).put(serializedObj)
-
- assertResult(obj, "Deserialized object didn't equal to the original object") {
- buffer.rewind()
- SparkSqlSerializer.deserialize(MAP_GENERIC.extract(buffer))
- }
+ testNativeColumnType(BOOLEAN)
+ testNativeColumnType(BYTE)
+ testNativeColumnType(SHORT)
+ testNativeColumnType(INT)
+ testNativeColumnType(LONG)
+ testNativeColumnType(FLOAT)
+ testNativeColumnType(DOUBLE)
+ testNativeColumnType(COMPACT_DECIMAL(15, 10))
+ testNativeColumnType(STRING)
+
+ testColumnType(NULL)
+ testColumnType(BINARY)
+ testColumnType(LARGE_DECIMAL(20, 10))
+ testColumnType(STRUCT_TYPE)
+ testColumnType(ARRAY_TYPE)
+ testColumnType(MAP_TYPE)
+
+ def testNativeColumnType[T <: AtomicType](columnType: NativeColumnType[T]): Unit = {
+ testColumnType[T#InternalType](columnType)
}
- test("CUSTOM") {
- val conf = new SparkConf()
- conf.set("spark.kryo.registrator", "org.apache.spark.sql.columnar.Registrator")
- val serializer = new SparkSqlSerializer(conf).newInstance()
-
- val buffer = ByteBuffer.allocate(512)
- val obj = CustomClass(Int.MaxValue, Long.MaxValue)
- val serializedObj = serializer.serialize(obj).array()
-
- MAP_GENERIC.append(serializer.serialize(obj).array(), buffer)
- buffer.rewind()
-
- val length = buffer.getInt
- assert(length === serializedObj.length)
- assert(13 == length) // id (1) + int (4) + long (8)
-
- val genericSerializedObj = SparkSqlSerializer.serialize(obj)
- assert(length != genericSerializedObj.length)
- assert(length < genericSerializedObj.length)
-
- assertResult(obj, "Custom deserialized object didn't equal the original object") {
- val bytes = new Array[Byte](length)
- buffer.get(bytes, 0, length)
- serializer.deserialize(ByteBuffer.wrap(bytes))
- }
-
- buffer.rewind()
- buffer.putInt(serializedObj.length).put(serializedObj)
-
- assertResult(obj, "Custom deserialized object didn't equal the original object") {
- buffer.rewind()
- serializer.deserialize(ByteBuffer.wrap(MAP_GENERIC.extract(buffer)))
- }
- }
-
- def testNativeColumnType[T <: AtomicType](
- columnType: NativeColumnType[T])
- (putter: (ByteBuffer, T#InternalType) => Unit,
- getter: (ByteBuffer) => T#InternalType): Unit = {
-
- testColumnType[T#InternalType](columnType, putter, getter)
- }
-
- def testColumnType[JvmType](
- columnType: ColumnType[JvmType],
- putter: (ByteBuffer, JvmType) => Unit,
- getter: (ByteBuffer) => JvmType): Unit = {
+ def testColumnType[JvmType](columnType: ColumnType[JvmType]): Unit = {
val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE)
val seq = (0 until 4).map(_ => makeRandomValue(columnType))
+ val converter = CatalystTypeConverters.createToScalaConverter(columnType.dataType)
- test(s"$columnType.extract") {
+ test(s"$columnType append/extract") {
buffer.rewind()
- seq.foreach(putter(buffer, _))
+ seq.foreach(columnType.append(_, buffer))
buffer.rewind()
seq.foreach { expected =>
logInfo("buffer = " + buffer + ", expected = " + expected)
val extracted = columnType.extract(buffer)
assert(
- expected === extracted,
+ converter(expected) === converter(extracted),
"Extracted value didn't equal to the original one. " +
hexDump(expected) + " != " + hexDump(extracted) +
", buffer = " + dumpBuffer(buffer.duplicate().rewind().asInstanceOf[ByteBuffer]))
}
}
-
- test(s"$columnType.append") {
- buffer.rewind()
- seq.foreach(columnType.append(_, buffer))
-
- buffer.rewind()
- seq.foreach { expected =>
- assert(
- expected === getter(buffer),
- "Extracted value didn't equal to the original one")
- }
- }
}
private def hexDump(value: Any): String = {
- value.toString.map(ch => Integer.toHexString(ch & 0xffff)).mkString(" ")
+ if (value == null) {
+ ""
+ } else {
+ value.toString.map(ch => Integer.toHexString(ch & 0xffff)).mkString(" ")
+ }
}
private def dumpBuffer(buff: ByteBuffer): Any = {
@@ -253,33 +140,13 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
test("column type for decimal types with different precision") {
(1 to 18).foreach { i =>
- assertResult(FIXED_DECIMAL(i, 0)) {
+ assertResult(COMPACT_DECIMAL(i, 0)) {
ColumnType(DecimalType(i, 0))
}
}
- assertResult(GENERIC(DecimalType(19, 0))) {
+ assertResult(LARGE_DECIMAL(19, 0)) {
ColumnType(DecimalType(19, 0))
}
}
}
-
-private[columnar] final case class CustomClass(a: Int, b: Long)
-
-private[columnar] object CustomerSerializer extends Serializer[CustomClass] {
- override def write(kryo: Kryo, output: Output, t: CustomClass) {
- output.writeInt(t.a)
- output.writeLong(t.b)
- }
- override def read(kryo: Kryo, input: Input, aClass: Class[CustomClass]): CustomClass = {
- val a = input.readInt()
- val b = input.readLong()
- CustomClass(a, b)
- }
-}
-
-private[columnar] final class Registrator extends KryoRegistrator {
- override def registerClasses(kryo: Kryo) {
- kryo.register(classOf[CustomClass], CustomerSerializer)
- }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
index 123a7053c0..964cdb52b2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
@@ -21,8 +21,8 @@ import scala.collection.immutable.HashSet
import scala.util.Random
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.types.{AtomicType, Decimal}
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, GenericMutableRow}
+import org.apache.spark.sql.types.{ArrayBasedMapData, GenericArrayData, AtomicType, Decimal}
import org.apache.spark.unsafe.types.UTF8String
object ColumnarTestUtils {
@@ -40,6 +40,7 @@ object ColumnarTestUtils {
}
(columnType match {
+ case NULL => null
case BOOLEAN => Random.nextBoolean()
case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte
case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort
@@ -49,10 +50,15 @@ object ColumnarTestUtils {
case DOUBLE => Random.nextDouble()
case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32)))
case BINARY => randomBytes(Random.nextInt(32))
- case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
- case _ =>
- // Using a random one-element map instead of an arbitrary object
- Map(Random.nextInt() -> Random.nextString(Random.nextInt(32)))
+ case COMPACT_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
+ case LARGE_DECIMAL(precision, scale) => Decimal(Random.nextLong(), precision, scale)
+ case STRUCT(_) =>
+ new GenericInternalRow(Array[Any](UTF8String.fromString(Random.nextString(10))))
+ case ARRAY(_) =>
+ new GenericArrayData(Array[Any](Random.nextInt(), Random.nextInt()))
+ case MAP(_) =>
+ ArrayBasedMapData(
+ Map(Random.nextInt() -> UTF8String.fromString(Random.nextString(Random.nextInt(32)))))
}).asInstanceOf[JvmType]
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index ea5dd2be33..6265e40a0a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -157,7 +157,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
// Create a RDD for the schema
val rdd =
- sparkContext.parallelize((1 to 100), 10).map { i =>
+ sparkContext.parallelize((1 to 10000), 10).map { i =>
Row(
s"str${i}: test cache.",
s"binary${i}: test cache.".getBytes("UTF-8"),
@@ -172,9 +172,9 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
BigDecimal(Long.MaxValue.toString + ".12345"),
new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"),
new Date(i),
- new Timestamp(i),
- (1 to i).toSeq,
- (0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap,
+ new Timestamp(i * 1000000L),
+ (i to i + 10).toSeq,
+ (i to i + 10).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap,
Row((i - 0.25).toFloat, Seq(true, false, null)))
}
sqlContext.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
index a3a23d37d7..78cebbf3cc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
@@ -20,8 +20,9 @@ package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.types.{ArrayType, StringType}
+import org.apache.spark.sql.types._
class TestNullableColumnAccessor[JvmType](
buffer: ByteBuffer,
@@ -40,8 +41,10 @@ class NullableColumnAccessorSuite extends SparkFunSuite {
import org.apache.spark.sql.columnar.ColumnarTestUtils._
Seq(
- BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE,
- STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType)))
+ NULL, BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE,
+ STRING, BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
+ STRUCT(StructType(StructField("a", StringType) :: Nil)),
+ ARRAY(ArrayType(IntegerType)), MAP(MapType(IntegerType, StringType)))
.foreach {
testNullableColumnAccessor(_)
}
@@ -69,11 +72,13 @@ class NullableColumnAccessorSuite extends SparkFunSuite {
val accessor = TestNullableColumnAccessor(builder.build(), columnType)
val row = new GenericMutableRow(1)
+ val converter = CatalystTypeConverters.createToScalaConverter(columnType.dataType)
(0 until 4).foreach { _ =>
assert(accessor.hasNext)
accessor.extractTo(row, 0)
- assert(row.get(0, columnType.dataType) === randomRow.get(0, columnType.dataType))
+ assert(converter(row.get(0, columnType.dataType))
+ === converter(randomRow.get(0, columnType.dataType)))
assert(accessor.hasNext)
accessor.extractTo(row, 0)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
index 9557eead27..fba08e626d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
@@ -18,7 +18,8 @@
package org.apache.spark.sql.columnar
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.execution.SparkSqlSerializer
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.types._
class TestNullableColumnBuilder[JvmType](columnType: ColumnType[JvmType])
@@ -35,11 +36,13 @@ object TestNullableColumnBuilder {
}
class NullableColumnBuilderSuite extends SparkFunSuite {
- import ColumnarTestUtils._
+ import org.apache.spark.sql.columnar.ColumnarTestUtils._
Seq(
BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE,
- STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType)))
+ STRING, BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
+ STRUCT(StructType(StructField("a", StringType) :: Nil)),
+ ARRAY(ArrayType(IntegerType)), MAP(MapType(IntegerType, StringType)))
.foreach {
testNullableColumnBuilder(_)
}
@@ -74,6 +77,8 @@ class NullableColumnBuilderSuite extends SparkFunSuite {
val columnBuilder = TestNullableColumnBuilder(columnType)
val randomRow = makeRandomRow(columnType)
val nullRow = makeNullRow(1)
+ val dataType = columnType.dataType
+ val converter = CatalystTypeConverters.createToScalaConverter(dataType)
(0 until 4).foreach { _ =>
columnBuilder.appendFrom(randomRow, 0)
@@ -88,14 +93,10 @@ class NullableColumnBuilderSuite extends SparkFunSuite {
(1 to 7 by 2).foreach(assertResult(_, "Wrong null position")(buffer.getInt()))
// For non-null values
+ val actual = new GenericMutableRow(new Array[Any](1))
(0 until 4).foreach { _ =>
- val actual = if (columnType.isInstanceOf[GENERIC]) {
- SparkSqlSerializer.deserialize[Any](columnType.extract(buffer).asInstanceOf[Array[Byte]])
- } else {
- columnType.extract(buffer)
- }
-
- assert(actual === randomRow.get(0, columnType.dataType),
+ columnType.extract(buffer, actual, 0)
+ assert(converter(actual.get(0, dataType)) === converter(randomRow.get(0, dataType)),
"Extracted value didn't equal to the original one")
}