author    Cheng Lian <lian@databricks.com>    2015-03-14 19:53:54 +0800
committer Cheng Lian <lian@databricks.com>    2015-03-14 19:53:54 +0800
commit    5be6b0e4f48aca12fcd47c1b77c4675ad651c332 (patch)
tree      41d92dac84a474d14f30b190e2b3e7cbd106ae59 /sql
parent    ee15404a2b0009fc70119ac7af69137b54890d48 (diff)
[SPARK-6195] [SQL] Adds in-memory column type for fixed-precision decimals
This PR adds a specialized in-memory column type for fixed-precision decimals.

For all other column types, a single integer column type ID is enough to determine which column type to use. However, this doesn't apply to fixed-precision decimal types with different precision and scale parameters. Moreover, under the previous design there seems to be no trivial way to encode precision and scale information into the columnar byte buffer. On the other hand, we always know the data type of the column to be built or scanned ahead of time, so this PR no longer uses the column type ID to construct `ColumnBuilder`s and `ColumnAccessor`s, but resorts to the actual column data type. In this way, we can pass precision and scale information along. The column type ID is no longer used and can be removed in a future PR.

### Micro benchmark result

The following micro benchmark builds a simple table with 2 million decimals (precision = 10, scale = 0), caches it in memory, and then counts all the rows. Code (simply paste it into the Spark shell):

```scala
import sc._
import sqlContext._
import sqlContext.implicits._
import org.apache.spark.sql.types._
import com.google.common.base.Stopwatch

def benchmark(n: Int)(f: => Long) {
  val stopwatch = new Stopwatch()

  def run() = {
    stopwatch.reset()
    stopwatch.start()
    f
    stopwatch.stop()
    stopwatch.elapsedMillis()
  }

  val records = (0 until n).map(_ => run())

  (0 until n).foreach(i => println(s"Round $i: ${records(i)} ms"))
  println(s"Average: ${records.sum / n.toDouble} ms")
}

// Explicit casting is required because ScalaReflection can't inspect decimal precision
parallelize(1 to 2000000)
  .map(i => Tuple1(Decimal(i, 10, 0)))
  .toDF("dec")
  .select($"dec" cast DecimalType(10, 0))
  .registerTempTable("dec")

sql("CACHE TABLE dec")
val df = table("dec")

// Warm up
df.count()
df.count()

benchmark(5) {
  df.count()
}
```

With the `FIXED_DECIMAL` column type:

- Round 0: 75 ms
- Round 1: 97 ms
- Round 2: 75 ms
- Round 3: 70 ms
- Round 4: 72 ms
- Average: 77.8 ms

Without the `FIXED_DECIMAL` column type:

- Round 0: 1233 ms
- Round 1: 1170 ms
- Round 2: 1171 ms
- Round 3: 1141 ms
- Round 4: 1141 ms
- Average: 1171.2 ms

Author: Cheng Lian <lian@databricks.com>

Closes #4938 from liancheng/decimal-column-type and squashes the following commits:

fef5338 [Cheng Lian] Updates fixed decimal column type related test cases
e08ab5b [Cheng Lian] Only resorts to FIXED_DECIMAL when the value can be held in a long
4db713d [Cheng Lian] Adds in-memory column type for fixed-precision decimals
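For reference, here is a minimal, self-contained sketch of the encoding idea described above, assuming only Spark's `org.apache.spark.sql.types.Decimal` class on the classpath (the object and value names are illustrative, not part of the patch): a fixed-precision decimal whose unscaled value fits in a `Long` is stored as just that 8-byte long, while precision and scale travel with the column type, mirroring the `FIXED_DECIMAL.append` and `FIXED_DECIMAL.extract` methods in the diff below.

```scala
import java.nio.ByteBuffer

import org.apache.spark.sql.types.Decimal

// Illustrative, standalone round-trip of the FIXED_DECIMAL encoding: only the
// unscaled long is written to the buffer; precision and scale are statically known.
object FixedDecimalEncodingSketch {
  def main(args: Array[String]): Unit = {
    val precision = 10
    val scale = 0
    val original = Decimal(1234567890L, precision, scale)

    // Append side: store the unscaled value in an 8-byte slot
    // (what FIXED_DECIMAL.append does via buffer.putLong(v.toUnscaledLong)).
    val buffer = ByteBuffer.allocate(8)
    buffer.putLong(original.toUnscaledLong)
    buffer.flip()

    // Extract side: rebuild the Decimal from the long plus the known precision/scale
    // (what FIXED_DECIMAL.extract does via Decimal(buffer.getLong(), precision, scale)).
    val restored = Decimal(buffer.getLong(), precision, scale)

    assert(restored.compareTo(original) == 0)
    println(s"round-tripped $restored as a single long")
  }
}
```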
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala            | 43
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala             | 39
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala               | 17
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala                | 55
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala |  8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala          |  1
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala           | 46
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala         | 23
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala | 17
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala |  3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala  |  3
11 files changed, 179 insertions, 76 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
index 91c4c105b1..b615eaa0dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
@@ -21,7 +21,7 @@ import java.nio.{ByteBuffer, ByteOrder}
import org.apache.spark.sql.catalyst.expressions.MutableRow
import org.apache.spark.sql.columnar.compression.CompressibleColumnAccessor
-import org.apache.spark.sql.types.{BinaryType, DataType, NativeType}
+import org.apache.spark.sql.types._
/**
* An `Iterator` like trait used to extract values from columnar byte buffer. When a value is
@@ -89,6 +89,9 @@ private[sql] class DoubleColumnAccessor(buffer: ByteBuffer)
private[sql] class FloatColumnAccessor(buffer: ByteBuffer)
extends NativeColumnAccessor(buffer, FLOAT)
+private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: Int, scale: Int)
+ extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale))
+
private[sql] class StringColumnAccessor(buffer: ByteBuffer)
extends NativeColumnAccessor(buffer, STRING)
@@ -107,24 +110,28 @@ private[sql] class GenericColumnAccessor(buffer: ByteBuffer)
with NullableColumnAccessor
private[sql] object ColumnAccessor {
- def apply(buffer: ByteBuffer): ColumnAccessor = {
+ def apply(dataType: DataType, buffer: ByteBuffer): ColumnAccessor = {
val dup = buffer.duplicate().order(ByteOrder.nativeOrder)
- // The first 4 bytes in the buffer indicate the column type.
- val columnTypeId = dup.getInt()
-
- columnTypeId match {
- case INT.typeId => new IntColumnAccessor(dup)
- case LONG.typeId => new LongColumnAccessor(dup)
- case FLOAT.typeId => new FloatColumnAccessor(dup)
- case DOUBLE.typeId => new DoubleColumnAccessor(dup)
- case BOOLEAN.typeId => new BooleanColumnAccessor(dup)
- case BYTE.typeId => new ByteColumnAccessor(dup)
- case SHORT.typeId => new ShortColumnAccessor(dup)
- case STRING.typeId => new StringColumnAccessor(dup)
- case DATE.typeId => new DateColumnAccessor(dup)
- case TIMESTAMP.typeId => new TimestampColumnAccessor(dup)
- case BINARY.typeId => new BinaryColumnAccessor(dup)
- case GENERIC.typeId => new GenericColumnAccessor(dup)
+
+ // The first 4 bytes in the buffer indicate the column type. This field is not used now,
+ // because we always know the data type of the column ahead of time.
+ dup.getInt()
+
+ dataType match {
+ case IntegerType => new IntColumnAccessor(dup)
+ case LongType => new LongColumnAccessor(dup)
+ case FloatType => new FloatColumnAccessor(dup)
+ case DoubleType => new DoubleColumnAccessor(dup)
+ case BooleanType => new BooleanColumnAccessor(dup)
+ case ByteType => new ByteColumnAccessor(dup)
+ case ShortType => new ShortColumnAccessor(dup)
+ case StringType => new StringColumnAccessor(dup)
+ case BinaryType => new BinaryColumnAccessor(dup)
+ case DateType => new DateColumnAccessor(dup)
+ case TimestampType => new TimestampColumnAccessor(dup)
+ case DecimalType.Fixed(precision, scale) if precision < 19 =>
+ new FixedDecimalColumnAccessor(dup, precision, scale)
+ case _ => new GenericColumnAccessor(dup)
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
index 3a4977b836..d8d24a5773 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
@@ -106,6 +106,13 @@ private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleCol
private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT)
+private[sql] class FixedDecimalColumnBuilder(
+ precision: Int,
+ scale: Int)
+ extends NativeColumnBuilder(
+ new FixedDecimalColumnStats,
+ FIXED_DECIMAL(precision, scale))
+
private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING)
private[sql] class DateColumnBuilder extends NativeColumnBuilder(new DateColumnStats, DATE)
@@ -139,25 +146,25 @@ private[sql] object ColumnBuilder {
}
def apply(
- typeId: Int,
+ dataType: DataType,
initialSize: Int = 0,
columnName: String = "",
useCompression: Boolean = false): ColumnBuilder = {
-
- val builder = (typeId match {
- case INT.typeId => new IntColumnBuilder
- case LONG.typeId => new LongColumnBuilder
- case FLOAT.typeId => new FloatColumnBuilder
- case DOUBLE.typeId => new DoubleColumnBuilder
- case BOOLEAN.typeId => new BooleanColumnBuilder
- case BYTE.typeId => new ByteColumnBuilder
- case SHORT.typeId => new ShortColumnBuilder
- case STRING.typeId => new StringColumnBuilder
- case BINARY.typeId => new BinaryColumnBuilder
- case GENERIC.typeId => new GenericColumnBuilder
- case DATE.typeId => new DateColumnBuilder
- case TIMESTAMP.typeId => new TimestampColumnBuilder
- }).asInstanceOf[ColumnBuilder]
+ val builder: ColumnBuilder = dataType match {
+ case IntegerType => new IntColumnBuilder
+ case LongType => new LongColumnBuilder
+ case DoubleType => new DoubleColumnBuilder
+ case BooleanType => new BooleanColumnBuilder
+ case ByteType => new ByteColumnBuilder
+ case ShortType => new ShortColumnBuilder
+ case StringType => new StringColumnBuilder
+ case BinaryType => new BinaryColumnBuilder
+ case DateType => new DateColumnBuilder
+ case TimestampType => new TimestampColumnBuilder
+ case DecimalType.Fixed(precision, scale) if precision < 19 =>
+ new FixedDecimalColumnBuilder(precision, scale)
+ case _ => new GenericColumnBuilder
+ }
builder.initialize(initialSize, columnName, useCompression)
builder
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
index cad0667b46..04047b9c06 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -181,6 +181,23 @@ private[sql] class FloatColumnStats extends ColumnStats {
def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
}
+private[sql] class FixedDecimalColumnStats extends ColumnStats {
+ protected var upper: Decimal = null
+ protected var lower: Decimal = null
+
+ override def gatherStats(row: Row, ordinal: Int): Unit = {
+ super.gatherStats(row, ordinal)
+ if (!row.isNullAt(ordinal)) {
+ val value = row(ordinal).asInstanceOf[Decimal]
+ if (upper == null || value.compareTo(upper) > 0) upper = value
+ if (lower == null || value.compareTo(lower) < 0) lower = value
+ sizeInBytes += FIXED_DECIMAL.defaultSize
+ }
+ }
+
+ override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes)
+}
+
private[sql] class IntColumnStats extends ColumnStats {
protected var upper = Int.MinValue
protected var lower = Int.MaxValue
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index db5bc0de36..36ea1c77e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -373,6 +373,33 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 12) {
}
}
+private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int)
+ extends NativeColumnType(
+ DecimalType(Some(PrecisionInfo(precision, scale))),
+ 10,
+ FIXED_DECIMAL.defaultSize) {
+
+ override def extract(buffer: ByteBuffer): Decimal = {
+ Decimal(buffer.getLong(), precision, scale)
+ }
+
+ override def append(v: Decimal, buffer: ByteBuffer): Unit = {
+ buffer.putLong(v.toUnscaledLong)
+ }
+
+ override def getField(row: Row, ordinal: Int): Decimal = {
+ row(ordinal).asInstanceOf[Decimal]
+ }
+
+ override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = {
+ row(ordinal) = value
+ }
+}
+
+private[sql] object FIXED_DECIMAL {
+ val defaultSize = 8
+}
+
private[sql] sealed abstract class ByteArrayColumnType[T <: DataType](
typeId: Int,
defaultSize: Int)
@@ -394,7 +421,7 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType](
}
}
-private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](10, 16) {
+private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](11, 16) {
override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = {
row(ordinal) = value
}
@@ -405,7 +432,7 @@ private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](10, 16)
// Used to process generic objects (all types other than those listed above). Objects should be
// serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized
// byte array.
-private[sql] object GENERIC extends ByteArrayColumnType[DataType](11, 16) {
+private[sql] object GENERIC extends ByteArrayColumnType[DataType](12, 16) {
override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = {
row(ordinal) = SparkSqlSerializer.deserialize[Any](value)
}
@@ -416,18 +443,20 @@ private[sql] object GENERIC extends ByteArrayColumnType[DataType](11, 16) {
private[sql] object ColumnType {
def apply(dataType: DataType): ColumnType[_, _] = {
dataType match {
- case IntegerType => INT
- case LongType => LONG
- case FloatType => FLOAT
- case DoubleType => DOUBLE
- case BooleanType => BOOLEAN
- case ByteType => BYTE
- case ShortType => SHORT
- case StringType => STRING
- case BinaryType => BINARY
- case DateType => DATE
+ case IntegerType => INT
+ case LongType => LONG
+ case FloatType => FLOAT
+ case DoubleType => DOUBLE
+ case BooleanType => BOOLEAN
+ case ByteType => BYTE
+ case ShortType => SHORT
+ case StringType => STRING
+ case BinaryType => BINARY
+ case DateType => DATE
case TimestampType => TIMESTAMP
- case _ => GENERIC
+ case DecimalType.Fixed(precision, scale) if precision < 19 =>
+ FIXED_DECIMAL(precision, scale)
+ case _ => GENERIC
}
}
}
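A side note on the `precision < 19` guard used above (and in the corresponding `ColumnAccessor` and `ColumnBuilder` cases): an unscaled value with at most 18 decimal digits always fits in a signed 64-bit long, while 19 digits may not, which is why larger decimals fall back to `GENERIC`. A quick standalone check of that bound (the object name is purely illustrative):

```scala
// Largest unscaled values for 18- and 19-digit decimals versus Long.MaxValue.
object UnscaledLongBoundCheck {
  def main(args: Array[String]): Unit = {
    val maxUnscaled18 = BigInt(10).pow(18) - 1  // 999999999999999999
    val maxUnscaled19 = BigInt(10).pow(19) - 1  // 9999999999999999999
    val longMax = BigInt(Long.MaxValue)         // 9223372036854775807

    println(maxUnscaled18 <= longMax)  // true:  18-digit decimals fit in a long
    println(maxUnscaled19 <= longMax)  // false: 19-digit decimals may overflow
  }
}
```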
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index 8944a32bc3..387faee12b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -113,7 +113,7 @@ private[sql] case class InMemoryRelation(
val columnBuilders = output.map { attribute =>
val columnType = ColumnType(attribute.dataType)
val initialBufferSize = columnType.defaultSize * batchSize
- ColumnBuilder(columnType.typeId, initialBufferSize, attribute.name, useCompression)
+ ColumnBuilder(attribute.dataType, initialBufferSize, attribute.name, useCompression)
}.toArray
var rowCount = 0
@@ -274,8 +274,10 @@ private[sql] case class InMemoryColumnarTableScan(
def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]) = {
val rows = cacheBatches.flatMap { cachedBatch =>
// Build column accessors
- val columnAccessors = requestedColumnIndices.map { batch =>
- ColumnAccessor(ByteBuffer.wrap(cachedBatch.buffers(batch)))
+ val columnAccessors = requestedColumnIndices.map { batchColumnIndex =>
+ ColumnAccessor(
+ relation.output(batchColumnIndex).dataType,
+ ByteBuffer.wrap(cachedBatch.buffers(batchColumnIndex)))
}
// Extract rows via column accessors
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index 581fccf8ee..fec487f1d2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -29,6 +29,7 @@ class ColumnStatsSuite extends FunSuite {
testColumnStats(classOf[LongColumnStats], LONG, Row(Long.MaxValue, Long.MinValue, 0))
testColumnStats(classOf[FloatColumnStats], FLOAT, Row(Float.MaxValue, Float.MinValue, 0))
testColumnStats(classOf[DoubleColumnStats], DOUBLE, Row(Double.MaxValue, Double.MinValue, 0))
+ testColumnStats(classOf[FixedDecimalColumnStats], FIXED_DECIMAL(15, 10), Row(null, null, 0))
testColumnStats(classOf[StringColumnStats], STRING, Row(null, null, 0))
testColumnStats(classOf[DateColumnStats], DATE, Row(Int.MaxValue, Int.MinValue, 0))
testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(null, null, 0))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index 9ce845912f..5f08834f73 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -33,8 +33,9 @@ class ColumnTypeSuite extends FunSuite with Logging {
test("defaultSize") {
val checks = Map(
- INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, BOOLEAN -> 1,
- STRING -> 8, DATE -> 4, TIMESTAMP -> 12, BINARY -> 16, GENERIC -> 16)
+ INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4,
+ FIXED_DECIMAL(15, 10) -> 8, BOOLEAN -> 1, STRING -> 8, DATE -> 4, TIMESTAMP -> 12,
+ BINARY -> 16, GENERIC -> 16)
checks.foreach { case (columnType, expectedSize) =>
assertResult(expectedSize, s"Wrong defaultSize for $columnType") {
@@ -56,15 +57,16 @@ class ColumnTypeSuite extends FunSuite with Logging {
}
}
- checkActualSize(INT, Int.MaxValue, 4)
- checkActualSize(SHORT, Short.MaxValue, 2)
- checkActualSize(LONG, Long.MaxValue, 8)
- checkActualSize(BYTE, Byte.MaxValue, 1)
- checkActualSize(DOUBLE, Double.MaxValue, 8)
- checkActualSize(FLOAT, Float.MaxValue, 4)
- checkActualSize(BOOLEAN, true, 1)
- checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length)
- checkActualSize(DATE, 0, 4)
+ checkActualSize(INT, Int.MaxValue, 4)
+ checkActualSize(SHORT, Short.MaxValue, 2)
+ checkActualSize(LONG, Long.MaxValue, 8)
+ checkActualSize(BYTE, Byte.MaxValue, 1)
+ checkActualSize(DOUBLE, Double.MaxValue, 8)
+ checkActualSize(FLOAT, Float.MaxValue, 4)
+ checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
+ checkActualSize(BOOLEAN, true, 1)
+ checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length)
+ checkActualSize(DATE, 0, 4)
checkActualSize(TIMESTAMP, new Timestamp(0L), 12)
val binary = Array.fill[Byte](4)(0: Byte)
@@ -93,12 +95,20 @@ class ColumnTypeSuite extends FunSuite with Logging {
testNativeColumnType[DoubleType.type](DOUBLE, _.putDouble(_), _.getDouble)
+ testNativeColumnType[DecimalType](
+ FIXED_DECIMAL(15, 10),
+ (buffer: ByteBuffer, decimal: Decimal) => {
+ buffer.putLong(decimal.toUnscaledLong)
+ },
+ (buffer: ByteBuffer) => {
+ Decimal(buffer.getLong(), 15, 10)
+ })
+
testNativeColumnType[FloatType.type](FLOAT, _.putFloat(_), _.getFloat)
testNativeColumnType[StringType.type](
STRING,
(buffer: ByteBuffer, string: String) => {
-
val bytes = string.getBytes("utf-8")
buffer.putInt(bytes.length)
buffer.put(bytes)
@@ -206,4 +216,16 @@ class ColumnTypeSuite extends FunSuite with Logging {
if (sb.nonEmpty) sb.setLength(sb.length - 1)
sb.toString()
}
+
+ test("column type for decimal types with different precision") {
+ (1 to 18).foreach { i =>
+ assertResult(FIXED_DECIMAL(i, 0)) {
+ ColumnType(DecimalType(i, 0))
+ }
+ }
+
+ assertResult(GENERIC) {
+ ColumnType(DecimalType(19, 0))
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
index 60ed28cc97..c7a40845db 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
@@ -24,7 +24,7 @@ import scala.util.Random
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.types.{DataType, NativeType}
+import org.apache.spark.sql.types.{Decimal, DataType, NativeType}
object ColumnarTestUtils {
def makeNullRow(length: Int) = {
@@ -41,16 +41,17 @@ object ColumnarTestUtils {
}
(columnType match {
- case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte
- case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort
- case INT => Random.nextInt()
- case LONG => Random.nextLong()
- case FLOAT => Random.nextFloat()
- case DOUBLE => Random.nextDouble()
- case STRING => Random.nextString(Random.nextInt(32))
- case BOOLEAN => Random.nextBoolean()
- case BINARY => randomBytes(Random.nextInt(32))
- case DATE => Random.nextInt()
+ case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte
+ case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort
+ case INT => Random.nextInt()
+ case LONG => Random.nextLong()
+ case FLOAT => Random.nextFloat()
+ case DOUBLE => Random.nextDouble()
+ case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
+ case STRING => Random.nextString(Random.nextInt(32))
+ case BOOLEAN => Random.nextBoolean()
+ case BINARY => randomBytes(Random.nextInt(32))
+ case DATE => Random.nextInt()
case TIMESTAMP =>
val timestamp = new Timestamp(Random.nextLong())
timestamp.setNanos(Random.nextInt(999999999))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index 38b0f666ab..27dfabca90 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -17,11 +17,11 @@
package org.apache.spark.sql.columnar
-import org.apache.spark.sql.functions._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.test.TestSQLContext.implicits._
+import org.apache.spark.sql.types.{DecimalType, Decimal}
import org.apache.spark.sql.{QueryTest, TestData}
import org.apache.spark.storage.StorageLevel.MEMORY_ONLY
@@ -117,4 +117,19 @@ class InMemoryColumnarQuerySuite extends QueryTest {
complexData.count()
complexData.unpersist()
}
+
+ test("decimal type") {
+ // Casting is required here because ScalaReflection can't capture decimal precision information.
+ val df = (1 to 10)
+ .map(i => Tuple1(Decimal(i, 15, 10)))
+ .toDF("dec")
+ .select($"dec" cast DecimalType(15, 10))
+
+ assert(df.schema.head.dataType === DecimalType(15, 10))
+
+ df.cache().registerTempTable("test_fixed_decimal")
+ checkAnswer(
+ sql("SELECT * FROM test_fixed_decimal"),
+ (1 to 10).map(i => Row(Decimal(i, 15, 10).toJavaBigDecimal)))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
index f95c895587..bb30535527 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
@@ -42,7 +42,8 @@ class NullableColumnAccessorSuite extends FunSuite {
import ColumnarTestUtils._
Seq(
- INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, DATE, TIMESTAMP
+ INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, FIXED_DECIMAL(15, 10), BINARY, GENERIC,
+ DATE, TIMESTAMP
).foreach {
testNullableColumnAccessor(_)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
index 80bd5c9457..75a4749868 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
@@ -38,7 +38,8 @@ class NullableColumnBuilderSuite extends FunSuite {
import ColumnarTestUtils._
Seq(
- INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, DATE, TIMESTAMP
+ INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, FIXED_DECIMAL(15, 10), BINARY, GENERIC,
+ DATE, TIMESTAMP
).foreach {
testNullableColumnBuilder(_)
}