about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Liang-Chi Hsieh <viirya@appier.com> 2015-11-22 10:36:47 -0800
committer: Michael Armbrust <michael@databricks.com> 2015-11-22 10:36:47 -0800
commit: 426004a9c9a864f90494d08601e6974709091a56 (patch)
tree: 03e6833e66a98a6d327cacaee3c3dedf095877e6
parent: ff442bbcffd4f93cfcc2f76d160011e725d2fb3f (diff)
download: spark-426004a9c9a864f90494d08601e6974709091a56.tar.gz
          spark-426004a9c9a864f90494d08601e6974709091a56.tar.bz2
          spark-426004a9c9a864f90494d08601e6974709091a56.zip
[SPARK-11908][SQL] Add NullType support to RowEncoder
JIRA: https://issues.apache.org/jira/browse/SPARK-11908

We should add NullType support to RowEncoder.

Author: Liang-Chi Hsieh <viirya@appier.com>

Closes #9891 from viirya/rowencoder-nulltype.
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala | 5
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala | 3
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala | 3
3 files changed, 9 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 4cda4824ac..fa553e7c53 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -48,7 +48,7 @@ object RowEncoder {
private def extractorsFor(
inputObject: Expression,
inputType: DataType): Expression = inputType match {
- case BooleanType | ByteType | ShortType | IntegerType | LongType |
+ case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType | BinaryType => inputObject
case udt: UserDefinedType[_] =>
@@ -143,6 +143,7 @@ object RowEncoder {
case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]])
case _: StructType => ObjectType(classOf[Row])
case udt: UserDefinedType[_] => ObjectType(udt.userClass)
+ case _: NullType => ObjectType(classOf[java.lang.Object])
}
private def constructorFor(schema: StructType): Expression = {
@@ -158,7 +159,7 @@ object RowEncoder {
}
private def constructorFor(input: Expression): Expression = input.dataType match {
- case BooleanType | ByteType | ShortType | IntegerType | LongType |
+ case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType | BinaryType => input
case udt: UserDefinedType[_] =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index ef7399e019..82317d3385 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -369,6 +369,9 @@ case class MapObjects(
private lazy val completeFunction = function(loopAttribute)
private def itemAccessorMethod(dataType: DataType): String => String = dataType match {
+ case NullType =>
+ val nullTypeClassName = NullType.getClass.getName + ".MODULE$"
+ (i: String) => s".get($i, $nullTypeClassName)"
case IntegerType => (i: String) => s".getInt($i)"
case LongType => (i: String) => s".getLong($i)"
case FloatType => (i: String) => s".getFloat($i)"
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 46c6e0d98d..0ea51ece4b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -80,11 +80,13 @@ class RowEncoderSuite extends SparkFunSuite {
private val structOfString = new StructType().add("str", StringType)
private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false)
private val arrayOfString = ArrayType(StringType)
+ private val arrayOfNull = ArrayType(NullType)
private val mapOfString = MapType(StringType, StringType)
private val arrayOfUDT = ArrayType(new ExamplePointUDT, false)
encodeDecodeTest(
new StructType()
+ .add("null", NullType)
.add("boolean", BooleanType)
.add("byte", ByteType)
.add("short", ShortType)
@@ -101,6 +103,7 @@ class RowEncoderSuite extends SparkFunSuite {
encodeDecodeTest(
new StructType()
+ .add("arrayOfNull", arrayOfNull)
.add("arrayOfString", arrayOfString)
.add("arrayOfArrayOfString", ArrayType(arrayOfString))
.add("arrayOfArrayOfInt", ArrayType(ArrayType(IntegerType)))