path: root/sql/core
author     Wenchen Fan <cloud0fan@outlook.com>    2015-06-29 11:41:26 -0700
committer  Cheng Lian <lian@databricks.com>       2015-06-29 11:41:26 -0700
commit     ed413bcc78d8d97a1a0cd0871d7a20f7170476d0 (patch)
tree       cc63e6f12c4c022af716832703eb08867a1ba95e /sql/core
parent     ea88b1a5077e6ba980b0de6d3bc508c62285ba4c (diff)
[SPARK-8692] [SQL] re-order the case statements that handle catalyst data types
Use the same order everywhere: boolean, byte, short, int, date, long, timestamp, float, double, string, binary, decimal. Then we can check at a glance whether any data type is missing, and make sure date/timestamp are handled just like int/long.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #7073 from cloud-fan/fix-date and squashes the following commits:

463044d [Wenchen Fan] fix style
51cd347 [Wenchen Fan] refactor handling of date and timestamp
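The convention can be illustrated with a small, self-contained sketch (not part of the patch); the describe helper and its labels are hypothetical and exist only to show the intended case order and the int/long backing of date and timestamp values:

  import org.apache.spark.sql.types._

  // Hypothetical helper: cases listed in the canonical order the commit adopts,
  // with DateType folded into the Int-backed branch and TimestampType into the
  // Long-backed branch, mirroring the serializer and Parquet writer changes below.
  def describe(dataType: DataType): String = dataType match {
    case BooleanType                => "boolean"
    case ByteType                   => "byte"
    case ShortType                  => "short"
    case IntegerType | DateType     => "int-backed"   // dates are stored as Int (days)
    case LongType | TimestampType   => "long-backed"  // timestamps are stored as Long
    case FloatType                  => "float"
    case DoubleType                 => "double"
    case StringType                 => "string"
    case BinaryType                 => "binary"
    case _: DecimalType             => "decimal"
    case _                          => "other"
  }

Listing the branches in one fixed order makes a missing type stand out immediately when scanning any of the match expressions touched by this patch.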
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala            42
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala             30
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala               74
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala                10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala      82
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala        34
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala                4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala           9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala           54
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala          8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala 6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala  6
12 files changed, 164 insertions(+), 195 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
index 64449b2659..931469bed6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
@@ -71,44 +71,44 @@ private[sql] abstract class NativeColumnAccessor[T <: AtomicType](
private[sql] class BooleanColumnAccessor(buffer: ByteBuffer)
extends NativeColumnAccessor(buffer, BOOLEAN)
-private[sql] class IntColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, INT)
+private[sql] class ByteColumnAccessor(buffer: ByteBuffer)
+ extends NativeColumnAccessor(buffer, BYTE)
private[sql] class ShortColumnAccessor(buffer: ByteBuffer)
extends NativeColumnAccessor(buffer, SHORT)
+private[sql] class IntColumnAccessor(buffer: ByteBuffer)
+ extends NativeColumnAccessor(buffer, INT)
+
private[sql] class LongColumnAccessor(buffer: ByteBuffer)
extends NativeColumnAccessor(buffer, LONG)
-private[sql] class ByteColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, BYTE)
-
-private[sql] class DoubleColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, DOUBLE)
-
private[sql] class FloatColumnAccessor(buffer: ByteBuffer)
extends NativeColumnAccessor(buffer, FLOAT)
-private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: Int, scale: Int)
- extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale))
+private[sql] class DoubleColumnAccessor(buffer: ByteBuffer)
+ extends NativeColumnAccessor(buffer, DOUBLE)
private[sql] class StringColumnAccessor(buffer: ByteBuffer)
extends NativeColumnAccessor(buffer, STRING)
-private[sql] class DateColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, DATE)
-
-private[sql] class TimestampColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, TIMESTAMP)
-
private[sql] class BinaryColumnAccessor(buffer: ByteBuffer)
extends BasicColumnAccessor[BinaryType.type, Array[Byte]](buffer, BINARY)
with NullableColumnAccessor
+private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: Int, scale: Int)
+ extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale))
+
private[sql] class GenericColumnAccessor(buffer: ByteBuffer)
extends BasicColumnAccessor[DataType, Array[Byte]](buffer, GENERIC)
with NullableColumnAccessor
+private[sql] class DateColumnAccessor(buffer: ByteBuffer)
+ extends NativeColumnAccessor(buffer, DATE)
+
+private[sql] class TimestampColumnAccessor(buffer: ByteBuffer)
+ extends NativeColumnAccessor(buffer, TIMESTAMP)
+
private[sql] object ColumnAccessor {
def apply(dataType: DataType, buffer: ByteBuffer): ColumnAccessor = {
val dup = buffer.duplicate().order(ByteOrder.nativeOrder)
@@ -118,17 +118,17 @@ private[sql] object ColumnAccessor {
dup.getInt()
dataType match {
+ case BooleanType => new BooleanColumnAccessor(dup)
+ case ByteType => new ByteColumnAccessor(dup)
+ case ShortType => new ShortColumnAccessor(dup)
case IntegerType => new IntColumnAccessor(dup)
+ case DateType => new DateColumnAccessor(dup)
case LongType => new LongColumnAccessor(dup)
+ case TimestampType => new TimestampColumnAccessor(dup)
case FloatType => new FloatColumnAccessor(dup)
case DoubleType => new DoubleColumnAccessor(dup)
- case BooleanType => new BooleanColumnAccessor(dup)
- case ByteType => new ByteColumnAccessor(dup)
- case ShortType => new ShortColumnAccessor(dup)
case StringType => new StringColumnAccessor(dup)
case BinaryType => new BinaryColumnAccessor(dup)
- case DateType => new DateColumnAccessor(dup)
- case TimestampType => new TimestampColumnAccessor(dup)
case DecimalType.Fixed(precision, scale) if precision < 19 =>
new FixedDecimalColumnAccessor(dup, precision, scale)
case _ => new GenericColumnAccessor(dup)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
index 1949625699..087c522397 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
@@ -94,17 +94,21 @@ private[sql] abstract class NativeColumnBuilder[T <: AtomicType](
private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN)
-private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT)
+private[sql] class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE)
private[sql] class ShortColumnBuilder extends NativeColumnBuilder(new ShortColumnStats, SHORT)
+private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT)
+
private[sql] class LongColumnBuilder extends NativeColumnBuilder(new LongColumnStats, LONG)
-private[sql] class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE)
+private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT)
private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE)
-private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT)
+private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING)
+
+private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY)
private[sql] class FixedDecimalColumnBuilder(
precision: Int,
@@ -113,19 +117,15 @@ private[sql] class FixedDecimalColumnBuilder(
new FixedDecimalColumnStats,
FIXED_DECIMAL(precision, scale))
-private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING)
+// TODO (lian) Add support for array, struct and map
+private[sql] class GenericColumnBuilder
+ extends ComplexColumnBuilder(new GenericColumnStats, GENERIC)
private[sql] class DateColumnBuilder extends NativeColumnBuilder(new DateColumnStats, DATE)
private[sql] class TimestampColumnBuilder
extends NativeColumnBuilder(new TimestampColumnStats, TIMESTAMP)
-private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY)
-
-// TODO (lian) Add support for array, struct and map
-private[sql] class GenericColumnBuilder
- extends ComplexColumnBuilder(new GenericColumnStats, GENERIC)
-
private[sql] object ColumnBuilder {
val DEFAULT_INITIAL_BUFFER_SIZE = 1024 * 1024
@@ -151,17 +151,17 @@ private[sql] object ColumnBuilder {
columnName: String = "",
useCompression: Boolean = false): ColumnBuilder = {
val builder: ColumnBuilder = dataType match {
+ case BooleanType => new BooleanColumnBuilder
+ case ByteType => new ByteColumnBuilder
+ case ShortType => new ShortColumnBuilder
case IntegerType => new IntColumnBuilder
+ case DateType => new DateColumnBuilder
case LongType => new LongColumnBuilder
+ case TimestampType => new TimestampColumnBuilder
case FloatType => new FloatColumnBuilder
case DoubleType => new DoubleColumnBuilder
- case BooleanType => new BooleanColumnBuilder
- case ByteType => new ByteColumnBuilder
- case ShortType => new ShortColumnBuilder
case StringType => new StringColumnBuilder
case BinaryType => new BinaryColumnBuilder
- case DateType => new DateColumnBuilder
- case TimestampType => new TimestampColumnBuilder
case DecimalType.Fixed(precision, scale) if precision < 19 =>
new FixedDecimalColumnBuilder(precision, scale)
case _ => new GenericColumnBuilder
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
index 1bce214d1d..00374d1fa3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -132,17 +132,17 @@ private[sql] class ShortColumnStats extends ColumnStats {
InternalRow(lower, upper, nullCount, count, sizeInBytes)
}
-private[sql] class LongColumnStats extends ColumnStats {
- protected var upper = Long.MinValue
- protected var lower = Long.MaxValue
+private[sql] class IntColumnStats extends ColumnStats {
+ protected var upper = Int.MinValue
+ protected var lower = Int.MaxValue
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
- val value = row.getLong(ordinal)
+ val value = row.getInt(ordinal)
if (value > upper) upper = value
if (value < lower) lower = value
- sizeInBytes += LONG.defaultSize
+ sizeInBytes += INT.defaultSize
}
}
@@ -150,17 +150,17 @@ private[sql] class LongColumnStats extends ColumnStats {
InternalRow(lower, upper, nullCount, count, sizeInBytes)
}
-private[sql] class DoubleColumnStats extends ColumnStats {
- protected var upper = Double.MinValue
- protected var lower = Double.MaxValue
+private[sql] class LongColumnStats extends ColumnStats {
+ protected var upper = Long.MinValue
+ protected var lower = Long.MaxValue
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
- val value = row.getDouble(ordinal)
+ val value = row.getLong(ordinal)
if (value > upper) upper = value
if (value < lower) lower = value
- sizeInBytes += DOUBLE.defaultSize
+ sizeInBytes += LONG.defaultSize
}
}
@@ -186,35 +186,17 @@ private[sql] class FloatColumnStats extends ColumnStats {
InternalRow(lower, upper, nullCount, count, sizeInBytes)
}
-private[sql] class FixedDecimalColumnStats extends ColumnStats {
- protected var upper: Decimal = null
- protected var lower: Decimal = null
-
- override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
- super.gatherStats(row, ordinal)
- if (!row.isNullAt(ordinal)) {
- val value = row(ordinal).asInstanceOf[Decimal]
- if (upper == null || value.compareTo(upper) > 0) upper = value
- if (lower == null || value.compareTo(lower) < 0) lower = value
- sizeInBytes += FIXED_DECIMAL.defaultSize
- }
- }
-
- override def collectedStatistics: InternalRow =
- InternalRow(lower, upper, nullCount, count, sizeInBytes)
-}
-
-private[sql] class IntColumnStats extends ColumnStats {
- protected var upper = Int.MinValue
- protected var lower = Int.MaxValue
+private[sql] class DoubleColumnStats extends ColumnStats {
+ protected var upper = Double.MinValue
+ protected var lower = Double.MaxValue
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
- val value = row.getInt(ordinal)
+ val value = row.getDouble(ordinal)
if (value > upper) upper = value
if (value < lower) lower = value
- sizeInBytes += INT.defaultSize
+ sizeInBytes += DOUBLE.defaultSize
}
}
@@ -240,10 +222,6 @@ private[sql] class StringColumnStats extends ColumnStats {
InternalRow(lower, upper, nullCount, count, sizeInBytes)
}
-private[sql] class DateColumnStats extends IntColumnStats
-
-private[sql] class TimestampColumnStats extends LongColumnStats
-
private[sql] class BinaryColumnStats extends ColumnStats {
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
@@ -256,6 +234,24 @@ private[sql] class BinaryColumnStats extends ColumnStats {
InternalRow(null, null, nullCount, count, sizeInBytes)
}
+private[sql] class FixedDecimalColumnStats extends ColumnStats {
+ protected var upper: Decimal = null
+ protected var lower: Decimal = null
+
+ override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+ super.gatherStats(row, ordinal)
+ if (!row.isNullAt(ordinal)) {
+ val value = row(ordinal).asInstanceOf[Decimal]
+ if (upper == null || value.compareTo(upper) > 0) upper = value
+ if (lower == null || value.compareTo(lower) < 0) lower = value
+ sizeInBytes += FIXED_DECIMAL.defaultSize
+ }
+ }
+
+ override def collectedStatistics: InternalRow =
+ InternalRow(lower, upper, nullCount, count, sizeInBytes)
+}
+
private[sql] class GenericColumnStats extends ColumnStats {
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
@@ -267,3 +263,7 @@ private[sql] class GenericColumnStats extends ColumnStats {
override def collectedStatistics: InternalRow =
InternalRow(null, null, nullCount, count, sizeInBytes)
}
+
+private[sql] class DateColumnStats extends IntColumnStats
+
+private[sql] class TimestampColumnStats extends LongColumnStats
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index 8bf2151e4d..fc72360c88 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -447,17 +447,17 @@ private[sql] object GENERIC extends ByteArrayColumnType[DataType](12, 16) {
private[sql] object ColumnType {
def apply(dataType: DataType): ColumnType[_, _] = {
dataType match {
+ case BooleanType => BOOLEAN
+ case ByteType => BYTE
+ case ShortType => SHORT
case IntegerType => INT
+ case DateType => DATE
case LongType => LONG
+ case TimestampType => TIMESTAMP
case FloatType => FLOAT
case DoubleType => DOUBLE
- case BooleanType => BOOLEAN
- case ByteType => BYTE
- case ShortType => SHORT
case StringType => STRING
case BinaryType => BINARY
- case DateType => DATE
- case TimestampType => TIMESTAMP
case DecimalType.Fixed(precision, scale) if precision < 19 =>
FIXED_DECIMAL(precision, scale)
case _ => GENERIC
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
index 74a22353b1..056d435eec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
@@ -237,7 +237,7 @@ private[sql] object SparkSqlSerializer2 {
out.writeShort(row.getShort(i))
}
- case IntegerType =>
+ case IntegerType | DateType =>
if (row.isNullAt(i)) {
out.writeByte(NULL)
} else {
@@ -245,7 +245,7 @@ private[sql] object SparkSqlSerializer2 {
out.writeInt(row.getInt(i))
}
- case LongType =>
+ case LongType | TimestampType =>
if (row.isNullAt(i)) {
out.writeByte(NULL)
} else {
@@ -269,55 +269,39 @@ private[sql] object SparkSqlSerializer2 {
out.writeDouble(row.getDouble(i))
}
- case decimal: DecimalType =>
+ case StringType =>
if (row.isNullAt(i)) {
out.writeByte(NULL)
} else {
out.writeByte(NOT_NULL)
- val value = row.apply(i).asInstanceOf[Decimal]
- val javaBigDecimal = value.toJavaBigDecimal
- // First, write out the unscaled value.
- val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray
+ val bytes = row.getAs[UTF8String](i).getBytes
out.writeInt(bytes.length)
out.write(bytes)
- // Then, write out the scale.
- out.writeInt(javaBigDecimal.scale())
}
- case DateType =>
- if (row.isNullAt(i)) {
- out.writeByte(NULL)
- } else {
- out.writeByte(NOT_NULL)
- out.writeInt(row.getAs[Int](i))
- }
-
- case TimestampType =>
- if (row.isNullAt(i)) {
- out.writeByte(NULL)
- } else {
- out.writeByte(NOT_NULL)
- out.writeLong(row.getAs[Long](i))
- }
-
- case StringType =>
+ case BinaryType =>
if (row.isNullAt(i)) {
out.writeByte(NULL)
} else {
out.writeByte(NOT_NULL)
- val bytes = row.getAs[UTF8String](i).getBytes
+ val bytes = row.getAs[Array[Byte]](i)
out.writeInt(bytes.length)
out.write(bytes)
}
- case BinaryType =>
+ case decimal: DecimalType =>
if (row.isNullAt(i)) {
out.writeByte(NULL)
} else {
out.writeByte(NOT_NULL)
- val bytes = row.getAs[Array[Byte]](i)
+ val value = row.apply(i).asInstanceOf[Decimal]
+ val javaBigDecimal = value.toJavaBigDecimal
+ // First, write out the unscaled value.
+ val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray
out.writeInt(bytes.length)
out.write(bytes)
+ // Then, write out the scale.
+ out.writeInt(javaBigDecimal.scale())
}
}
i += 1
@@ -364,14 +348,14 @@ private[sql] object SparkSqlSerializer2 {
mutableRow.setShort(i, in.readShort())
}
- case IntegerType =>
+ case IntegerType | DateType =>
if (in.readByte() == NULL) {
mutableRow.setNullAt(i)
} else {
mutableRow.setInt(i, in.readInt())
}
- case LongType =>
+ case LongType | TimestampType =>
if (in.readByte() == NULL) {
mutableRow.setNullAt(i)
} else {
@@ -392,53 +376,39 @@ private[sql] object SparkSqlSerializer2 {
mutableRow.setDouble(i, in.readDouble())
}
- case decimal: DecimalType =>
+ case StringType =>
if (in.readByte() == NULL) {
mutableRow.setNullAt(i)
} else {
- // First, read in the unscaled value.
val length = in.readInt()
val bytes = new Array[Byte](length)
in.readFully(bytes)
- val unscaledVal = new BigInteger(bytes)
- // Then, read the scale.
- val scale = in.readInt()
- // Finally, create the Decimal object and set it in the row.
- mutableRow.update(i, Decimal(new BigDecimal(unscaledVal, scale)))
- }
-
- case DateType =>
- if (in.readByte() == NULL) {
- mutableRow.setNullAt(i)
- } else {
- mutableRow.update(i, in.readInt())
- }
-
- case TimestampType =>
- if (in.readByte() == NULL) {
- mutableRow.setNullAt(i)
- } else {
- mutableRow.update(i, in.readLong())
+ mutableRow.update(i, UTF8String.fromBytes(bytes))
}
- case StringType =>
+ case BinaryType =>
if (in.readByte() == NULL) {
mutableRow.setNullAt(i)
} else {
val length = in.readInt()
val bytes = new Array[Byte](length)
in.readFully(bytes)
- mutableRow.update(i, UTF8String.fromBytes(bytes))
+ mutableRow.update(i, bytes)
}
- case BinaryType =>
+ case decimal: DecimalType =>
if (in.readByte() == NULL) {
mutableRow.setNullAt(i)
} else {
+ // First, read in the unscaled value.
val length = in.readInt()
val bytes = new Array[Byte](length)
in.readFully(bytes)
- mutableRow.update(i, bytes)
+ val unscaledVal = new BigInteger(bytes)
+ // Then, read the scale.
+ val scale = in.readInt()
+ // Finally, create the Decimal object and set it in the row.
+ mutableRow.update(i, Decimal(new BigDecimal(unscaledVal, scale)))
}
}
i += 1
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index 0d96a1e807..df2a96dfeb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -198,19 +198,18 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
private[parquet] def writePrimitive(schema: DataType, value: Any): Unit = {
if (value != null) {
schema match {
- case StringType => writer.addBinary(
- Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes))
- case BinaryType => writer.addBinary(
- Binary.fromByteArray(value.asInstanceOf[Array[Byte]]))
- case IntegerType => writer.addInteger(value.asInstanceOf[Int])
+ case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean])
+ case ByteType => writer.addInteger(value.asInstanceOf[Byte])
case ShortType => writer.addInteger(value.asInstanceOf[Short])
+ case IntegerType | DateType => writer.addInteger(value.asInstanceOf[Int])
case LongType => writer.addLong(value.asInstanceOf[Long])
case TimestampType => writeTimestamp(value.asInstanceOf[Long])
- case ByteType => writer.addInteger(value.asInstanceOf[Byte])
- case DoubleType => writer.addDouble(value.asInstanceOf[Double])
case FloatType => writer.addFloat(value.asInstanceOf[Float])
- case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean])
- case DateType => writer.addInteger(value.asInstanceOf[Int])
+ case DoubleType => writer.addDouble(value.asInstanceOf[Double])
+ case StringType => writer.addBinary(
+ Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes))
+ case BinaryType => writer.addBinary(
+ Binary.fromByteArray(value.asInstanceOf[Array[Byte]]))
case d: DecimalType =>
if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) {
sys.error(s"Unsupported datatype $d, cannot write to consumer")
@@ -353,19 +352,18 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
record: InternalRow,
index: Int): Unit = {
ctype match {
+ case BooleanType => writer.addBoolean(record.getBoolean(index))
+ case ByteType => writer.addInteger(record.getByte(index))
+ case ShortType => writer.addInteger(record.getShort(index))
+ case IntegerType | DateType => writer.addInteger(record.getInt(index))
+ case LongType => writer.addLong(record.getLong(index))
+ case TimestampType => writeTimestamp(record.getLong(index))
+ case FloatType => writer.addFloat(record.getFloat(index))
+ case DoubleType => writer.addDouble(record.getDouble(index))
case StringType => writer.addBinary(
Binary.fromByteArray(record(index).asInstanceOf[UTF8String].getBytes))
case BinaryType => writer.addBinary(
Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]]))
- case IntegerType => writer.addInteger(record.getInt(index))
- case ShortType => writer.addInteger(record.getShort(index))
- case LongType => writer.addLong(record.getLong(index))
- case ByteType => writer.addInteger(record.getByte(index))
- case DoubleType => writer.addDouble(record.getDouble(index))
- case FloatType => writer.addFloat(record.getFloat(index))
- case BooleanType => writer.addBoolean(record.getBoolean(index))
- case DateType => writer.addInteger(record.getInt(index))
- case TimestampType => writeTimestamp(record.getLong(index))
case d: DecimalType =>
if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) {
sys.error(s"Unsupported datatype $d, cannot write to consumer")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
index 4d5199a140..e748bd7857 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
@@ -38,8 +38,8 @@ import org.apache.spark.sql.types._
private[parquet] object ParquetTypesConverter extends Logging {
def isPrimitiveType(ctype: DataType): Boolean = ctype match {
- case _: NumericType | BooleanType | StringType | BinaryType => true
- case _: DataType => false
+ case _: NumericType | BooleanType | DateType | TimestampType | StringType | BinaryType => true
+ case _ => false
}
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index 1f37455dd0..9bd7b221e9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -22,19 +22,20 @@ import org.apache.spark.sql.catalyst.expressions.InternalRow
import org.apache.spark.sql.types._
class ColumnStatsSuite extends SparkFunSuite {
+ testColumnStats(classOf[BooleanColumnStats], BOOLEAN, InternalRow(true, false, 0))
testColumnStats(classOf[ByteColumnStats], BYTE, InternalRow(Byte.MaxValue, Byte.MinValue, 0))
testColumnStats(classOf[ShortColumnStats], SHORT, InternalRow(Short.MaxValue, Short.MinValue, 0))
testColumnStats(classOf[IntColumnStats], INT, InternalRow(Int.MaxValue, Int.MinValue, 0))
+ testColumnStats(classOf[DateColumnStats], DATE, InternalRow(Int.MaxValue, Int.MinValue, 0))
testColumnStats(classOf[LongColumnStats], LONG, InternalRow(Long.MaxValue, Long.MinValue, 0))
+ testColumnStats(classOf[TimestampColumnStats], TIMESTAMP,
+ InternalRow(Long.MaxValue, Long.MinValue, 0))
testColumnStats(classOf[FloatColumnStats], FLOAT, InternalRow(Float.MaxValue, Float.MinValue, 0))
testColumnStats(classOf[DoubleColumnStats], DOUBLE,
InternalRow(Double.MaxValue, Double.MinValue, 0))
+ testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0))
testColumnStats(classOf[FixedDecimalColumnStats],
FIXED_DECIMAL(15, 10), InternalRow(null, null, 0))
- testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0))
- testColumnStats(classOf[DateColumnStats], DATE, InternalRow(Int.MaxValue, Int.MinValue, 0))
- testColumnStats(classOf[TimestampColumnStats], TIMESTAMP,
- InternalRow(Long.MaxValue, Long.MinValue, 0))
def testColumnStats[T <: AtomicType, U <: ColumnStats](
columnStatsClass: Class[U],
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index 6daddfb2c4..4d46a65705 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -36,9 +36,9 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
test("defaultSize") {
val checks = Map(
- INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4,
- FIXED_DECIMAL(15, 10) -> 8, BOOLEAN -> 1, STRING -> 8, DATE -> 4, TIMESTAMP -> 8,
- BINARY -> 16, GENERIC -> 16)
+ BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, DATE -> 4,
+ LONG -> 8, TIMESTAMP -> 8, FLOAT -> 4, DOUBLE -> 8,
+ STRING -> 8, BINARY -> 16, FIXED_DECIMAL(15, 10) -> 8, GENERIC -> 16)
checks.foreach { case (columnType, expectedSize) =>
assertResult(expectedSize, s"Wrong defaultSize for $columnType") {
@@ -60,27 +60,24 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
}
}
- checkActualSize(INT, Int.MaxValue, 4)
+ checkActualSize(BOOLEAN, true, 1)
+ checkActualSize(BYTE, Byte.MaxValue, 1)
checkActualSize(SHORT, Short.MaxValue, 2)
+ checkActualSize(INT, Int.MaxValue, 4)
+ checkActualSize(DATE, Int.MaxValue, 4)
checkActualSize(LONG, Long.MaxValue, 8)
- checkActualSize(BYTE, Byte.MaxValue, 1)
- checkActualSize(DOUBLE, Double.MaxValue, 8)
+ checkActualSize(TIMESTAMP, Long.MaxValue, 8)
checkActualSize(FLOAT, Float.MaxValue, 4)
- checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
- checkActualSize(BOOLEAN, true, 1)
+ checkActualSize(DOUBLE, Double.MaxValue, 8)
checkActualSize(STRING, UTF8String.fromString("hello"), 4 + "hello".getBytes("utf-8").length)
- checkActualSize(DATE, 0, 4)
- checkActualSize(TIMESTAMP, 0L, 8)
-
- val binary = Array.fill[Byte](4)(0: Byte)
- checkActualSize(BINARY, binary, 4 + 4)
+ checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 4)
+ checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
val generic = Map(1 -> "a")
checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8)
}
- testNativeColumnType[BooleanType.type](
- BOOLEAN,
+ testNativeColumnType(BOOLEAN)(
(buffer: ByteBuffer, v: Boolean) => {
buffer.put((if (v) 1 else 0).toByte)
},
@@ -88,18 +85,23 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
buffer.get() == 1
})
- testNativeColumnType[IntegerType.type](INT, _.putInt(_), _.getInt)
+ testNativeColumnType(BYTE)(_.put(_), _.get)
+
+ testNativeColumnType(SHORT)(_.putShort(_), _.getShort)
+
+ testNativeColumnType(INT)(_.putInt(_), _.getInt)
+
+ testNativeColumnType(DATE)(_.putInt(_), _.getInt)
- testNativeColumnType[ShortType.type](SHORT, _.putShort(_), _.getShort)
+ testNativeColumnType(LONG)(_.putLong(_), _.getLong)
- testNativeColumnType[LongType.type](LONG, _.putLong(_), _.getLong)
+ testNativeColumnType(TIMESTAMP)(_.putLong(_), _.getLong)
- testNativeColumnType[ByteType.type](BYTE, _.put(_), _.get)
+ testNativeColumnType(FLOAT)(_.putFloat(_), _.getFloat)
- testNativeColumnType[DoubleType.type](DOUBLE, _.putDouble(_), _.getDouble)
+ testNativeColumnType(DOUBLE)(_.putDouble(_), _.getDouble)
- testNativeColumnType[DecimalType](
- FIXED_DECIMAL(15, 10),
+ testNativeColumnType(FIXED_DECIMAL(15, 10))(
(buffer: ByteBuffer, decimal: Decimal) => {
buffer.putLong(decimal.toUnscaledLong)
},
@@ -107,10 +109,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
Decimal(buffer.getLong(), 15, 10)
})
- testNativeColumnType[FloatType.type](FLOAT, _.putFloat(_), _.getFloat)
- testNativeColumnType[StringType.type](
- STRING,
+ testNativeColumnType(STRING)(
(buffer: ByteBuffer, string: UTF8String) => {
val bytes = string.getBytes
buffer.putInt(bytes.length)
@@ -197,8 +197,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
}
def testNativeColumnType[T <: AtomicType](
- columnType: NativeColumnType[T],
- putter: (ByteBuffer, T#InternalType) => Unit,
+ columnType: NativeColumnType[T])
+ (putter: (ByteBuffer, T#InternalType) => Unit,
getter: (ByteBuffer) => T#InternalType): Unit = {
testColumnType[T, T#InternalType](columnType, putter, getter)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
index 7c86eae3f7..d986133973 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
@@ -39,18 +39,18 @@ object ColumnarTestUtils {
}
(columnType match {
+ case BOOLEAN => Random.nextBoolean()
case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte
case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort
case INT => Random.nextInt()
+ case DATE => Random.nextInt()
case LONG => Random.nextLong()
+ case TIMESTAMP => Random.nextLong()
case FLOAT => Random.nextFloat()
case DOUBLE => Random.nextDouble()
- case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32)))
- case BOOLEAN => Random.nextBoolean()
case BINARY => randomBytes(Random.nextInt(32))
- case DATE => Random.nextInt()
- case TIMESTAMP => Random.nextLong()
+ case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
case _ =>
// Using a random one-element map instead of an arbitrary object
Map(Random.nextInt() -> Random.nextString(Random.nextInt(32)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
index 2a6e0c3765..9eaa769846 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
@@ -42,9 +42,9 @@ class NullableColumnAccessorSuite extends SparkFunSuite {
import ColumnarTestUtils._
Seq(
- INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, FIXED_DECIMAL(15, 10), BINARY, GENERIC,
- DATE, TIMESTAMP
- ).foreach {
+ BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE,
+ STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC)
+ .foreach {
testNullableColumnAccessor(_)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
index cb4e9f1eb7..17e9ae464b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
@@ -38,9 +38,9 @@ class NullableColumnBuilderSuite extends SparkFunSuite {
import ColumnarTestUtils._
Seq(
- INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, FIXED_DECIMAL(15, 10), BINARY, GENERIC,
- DATE, TIMESTAMP
- ).foreach {
+ BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE,
+ STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC)
+ .foreach {
testNullableColumnBuilder(_)
}