From f2996e0d12eeb989b1bfa51a3f6fa54ce1ed4fca Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 23 Nov 2015 10:15:40 -0800 Subject: [SPARK-11921][SQL] fix `nullable` of encoder schema Author: Wenchen Fan Closes #9906 from cloud-fan/nullable. --- .../sql/catalyst/encoders/ExpressionEncoder.scala | 15 +++++++-- .../catalyst/encoders/ExpressionEncoderSuite.scala | 38 +++++++++++++++++++++- 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 6eeba1442c..7bc9aed0b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -54,8 +54,13 @@ object ExpressionEncoder { val toRowExpression = ScalaReflection.extractorsFor[T](inputObject) val fromRowExpression = ScalaReflection.constructorFor[T] + val schema = ScalaReflection.schemaFor[T] match { + case ScalaReflection.Schema(s: StructType, _) => s + case ScalaReflection.Schema(dt, nullable) => new StructType().add("value", dt, nullable) + } + new ExpressionEncoder[T]( - toRowExpression.dataType, + schema, flat, toRowExpression.flatten, fromRowExpression, @@ -71,7 +76,13 @@ object ExpressionEncoder { encoders.foreach(_.assertUnresolved()) val schema = StructType(encoders.zipWithIndex.map { - case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema) + case (e, i) => + val (dataType, nullable) = if (e.flat) { + e.schema.head.dataType -> e.schema.head.nullable + } else { + e.schema -> true + } + StructField(s"_${i + 1}", dataType, nullable) }) val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 76459b34a4..d6ca138672 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} -import org.apache.spark.sql.types.ArrayType +import org.apache.spark.sql.types.{StructType, ArrayType} case class RepeatedStruct(s: Seq[PrimitiveData]) @@ -238,6 +238,42 @@ class ExpressionEncoderSuite extends SparkFunSuite { ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) } + test("nullable of encoder schema") { + def checkNullable[T: ExpressionEncoder](nullable: Boolean*): Unit = { + assert(implicitly[ExpressionEncoder[T]].schema.map(_.nullable) === nullable.toSeq) + } + + // test for flat encoders + checkNullable[Int](false) + checkNullable[Option[Int]](true) + checkNullable[java.lang.Integer](true) + checkNullable[String](true) + + // test for product encoders + checkNullable[(String, Int)](true, false) + checkNullable[(Int, java.lang.Long)](false, true) + + // test for nested product encoders + { + val schema = ExpressionEncoder[(Int, (String, Int))].schema + assert(schema(0).nullable === false) + assert(schema(1).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](0).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](1).nullable === false) + } + + // test for tupled encoders + { + val schema = ExpressionEncoder.tuple( + ExpressionEncoder[Int], + ExpressionEncoder[(String, Int)]).schema + assert(schema(0).nullable === false) + assert(schema(1).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](0).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](1).nullable === false) + } + } + private val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() outers.put(getClass.getName, this) private def encodeDecodeTest[T : ExpressionEncoder]( -- cgit v1.2.3