about summary refs log tree commit diff
diff options
context:
space:
mode:
author Wenchen Fan <wenchen@databricks.com> 2015-11-23 10:15:40 -0800
committer Michael Armbrust <michael@databricks.com> 2015-11-23 10:15:40 -0800
commit f2996e0d12eeb989b1bfa51a3f6fa54ce1ed4fca (patch)
tree 3e4878bae0fda35536360a053e993db2cb1ea624
parent 1a5baaa6517872b9a4fd6cd41c4b2cf1e390f6d1 (diff)
download spark-f2996e0d12eeb989b1bfa51a3f6fa54ce1ed4fca.tar.gz
spark-f2996e0d12eeb989b1bfa51a3f6fa54ce1ed4fca.tar.bz2
spark-f2996e0d12eeb989b1bfa51a3f6fa54ce1ed4fca.zip
[SPARK-11921][SQL] fix `nullable` of encoder schema
Author: Wenchen Fan <wenchen@databricks.com> Closes #9906 from cloud-fan/nullable.
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala 15
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala 38
2 files changed, 50 insertions, 3 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 6eeba1442c..7bc9aed0b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -54,8 +54,13 @@ object ExpressionEncoder {
val toRowExpression = ScalaReflection.extractorsFor[T](inputObject)
val fromRowExpression = ScalaReflection.constructorFor[T]
+ val schema = ScalaReflection.schemaFor[T] match {
+ case ScalaReflection.Schema(s: StructType, _) => s
+ case ScalaReflection.Schema(dt, nullable) => new StructType().add("value", dt, nullable)
+ }
+
new ExpressionEncoder[T](
- toRowExpression.dataType,
+ schema,
flat,
toRowExpression.flatten,
fromRowExpression,
@@ -71,7 +76,13 @@ object ExpressionEncoder {
encoders.foreach(_.assertUnresolved())
val schema = StructType(encoders.zipWithIndex.map {
- case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
+ case (e, i) =>
+ val (dataType, nullable) = if (e.flat) {
+ e.schema.head.dataType -> e.schema.head.nullable
+ } else {
+ e.schema -> true
+ }
+ StructField(s"_${i + 1}", dataType, nullable)
})
val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 76459b34a4..d6ca138672 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.Encoders
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
-import org.apache.spark.sql.types.ArrayType
+import org.apache.spark.sql.types.{StructType, ArrayType}
case class RepeatedStruct(s: Seq[PrimitiveData])
@@ -238,6 +238,42 @@ class ExpressionEncoderSuite extends SparkFunSuite {
ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc))
}
+ test("nullable of encoder schema") {
+ def checkNullable[T: ExpressionEncoder](nullable: Boolean*): Unit = {
+ assert(implicitly[ExpressionEncoder[T]].schema.map(_.nullable) === nullable.toSeq)
+ }
+
+ // test for flat encoders
+ checkNullable[Int](false)
+ checkNullable[Option[Int]](true)
+ checkNullable[java.lang.Integer](true)
+ checkNullable[String](true)
+
+ // test for product encoders
+ checkNullable[(String, Int)](true, false)
+ checkNullable[(Int, java.lang.Long)](false, true)
+
+ // test for nested product encoders
+ {
+ val schema = ExpressionEncoder[(Int, (String, Int))].schema
+ assert(schema(0).nullable === false)
+ assert(schema(1).nullable === true)
+ assert(schema(1).dataType.asInstanceOf[StructType](0).nullable === true)
+ assert(schema(1).dataType.asInstanceOf[StructType](1).nullable === false)
+ }
+
+ // test for tupled encoders
+ {
+ val schema = ExpressionEncoder.tuple(
+ ExpressionEncoder[Int],
+ ExpressionEncoder[(String, Int)]).schema
+ assert(schema(0).nullable === false)
+ assert(schema(1).nullable === true)
+ assert(schema(1).dataType.asInstanceOf[StructType](0).nullable === true)
+ assert(schema(1).dataType.asInstanceOf[StructType](1).nullable === false)
+ }
+ }
+
private val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap()
outers.put(getClass.getName, this)
private def encodeDecodeTest[T : ExpressionEncoder](