From 38b9e69623c14a675b14639e8291f5d29d2a0bc3 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 2 Dec 2016 12:30:13 +0800 Subject: [SPARK-18284][SQL] Make ExpressionEncoder.serializer.nullable precise ## What changes were proposed in this pull request? This PR makes `ExpressionEncoder.serializer.nullable` for flat encoder for a primitive type `false`. Since it is `true` for now, it is too conservative. While `ExpressionEncoder.schema` has correct information (e.g. ``), `serializer.head.nullable` of `ExpressionEncoder`, which got from `encoderFor[T]`, is always false. It is too conservative. This is accomplished by checking whether a type is one of primitive types. If it is `true`, `nullable` should be `false`. ## How was this patch tested? Added new tests for encoder and dataframe Author: Kazuaki Ishizaki Closes #15780 from kiszk/SPARK-18284. --- .../spark/sql/catalyst/JavaTypeInference.scala | 4 +++- .../spark/sql/catalyst/ScalaReflection.scala | 7 +++++-- .../sql/catalyst/encoders/ExpressionEncoder.scala | 7 ++----- .../expressions/ReferenceToExpressions.scala | 2 +- .../sql/catalyst/expressions/objects/objects.scala | 24 ++++++++++++++-------- .../catalyst/encoders/ExpressionEncoderSuite.scala | 19 ++++++++++++++++- 6 files changed, 44 insertions(+), 19 deletions(-) (limited to 'sql/catalyst/src') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 04f0cfce88..7e8e4dab72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -396,12 +396,14 @@ object JavaTypeInference { case _ if mapType.isAssignableFrom(typeToken) => val (keyType, valueType) = mapKeyValueType(typeToken) + ExternalMapToCatalyst( inputObject, ObjectType(keyType.getRawType), serializerFor(_, keyType), ObjectType(valueType.getRawType), - serializerFor(_, valueType) + serializerFor(_, valueType), + valueNullable = true ) case other => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 0aa21b9347..6e20096901 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -498,7 +498,8 @@ object ScalaReflection extends ScalaReflection { dataTypeFor(keyType), serializerFor(_, keyType, keyPath), dataTypeFor(valueType), - serializerFor(_, valueType, valuePath)) + serializerFor(_, valueType, valuePath), + valueNullable = !valueType.typeSymbol.asClass.isPrimitive) case t if t <:< localTypeOf[String] => StaticInvoke( @@ -590,7 +591,9 @@ object ScalaReflection extends ScalaReflection { "cannot be used as field name\n" + walkedTypePath.mkString("\n")) } - val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) + val fieldValue = Invoke( + AssertNotNull(inputObject, walkedTypePath), fieldName, dataTypeFor(fieldType), + returnNullable = !fieldType.typeSymbol.asClass.isPrimitive) val clsName = getClassNameFromType(fieldType) val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 9c4818db63..3757eccfa2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -60,7 +60,7 @@ object ExpressionEncoder { val cls = mirror.runtimeClass(tpe) val flat = !ScalaReflection.definedByConstructorParams(tpe) - val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true) + val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = !cls.isPrimitive) val nullSafeInput = if (flat) { inputObject } else { @@ -71,10 +71,7 @@ object ExpressionEncoder { val serializer = ScalaReflection.serializerFor[T](nullSafeInput) val deserializer = ScalaReflection.deserializerFor[T] - val schema = ScalaReflection.schemaFor[T] match { - case ScalaReflection.Schema(s: StructType, _) => s - case ScalaReflection.Schema(dt, nullable) => new StructType().add("value", dt, nullable) - } + val schema = serializer.dataType new ExpressionEncoder[T]( schema, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala index 6c75a7a502..2ca77e8394 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala @@ -74,7 +74,7 @@ case class ReferenceToExpressions(result: Expression, children: Seq[Expression]) ctx.addMutableState("boolean", classChildVarIsNull, "") val classChildVar = - LambdaVariable(classChildVarName, classChildVarIsNull, child.dataType) + LambdaVariable(classChildVarName, classChildVarIsNull, child.dataType, child.nullable) val initCode = s"${classChildVar.value} = ${childGen.value};\n" + s"${classChildVar.isNull} = ${childGen.isNull};" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index e517ec18eb..a8aa1e7255 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -171,15 +171,18 @@ case class StaticInvoke( * @param arguments An optional list of expressions, whos evaluation will be passed to the function. * @param propagateNull When true, and any of the arguments is null, null will be returned instead * of calling the function. + * @param returnNullable When false, indicating the invoked method will always return + * non-null value. */ case class Invoke( targetObject: Expression, functionName: String, dataType: DataType, arguments: Seq[Expression] = Nil, - propagateNull: Boolean = true) extends InvokeLike { + propagateNull: Boolean = true, + returnNullable : Boolean = true) extends InvokeLike { - override def nullable: Boolean = true + override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable override def children: Seq[Expression] = targetObject +: arguments override def eval(input: InternalRow): Any = @@ -405,13 +408,15 @@ case class WrapOption(child: Expression, optType: DataType) * A place holder for the loop variable used in [[MapObjects]]. This should never be constructed * manually, but will instead be passed into the provided lambda function. */ -case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends LeafExpression +case class LambdaVariable( + value: String, + isNull: String, + dataType: DataType, + nullable: Boolean = true) extends LeafExpression with Unevaluable with NonSQLExpression { - override def nullable: Boolean = true - override def genCode(ctx: CodegenContext): ExprCode = { - ExprCode(code = "", value = value, isNull = isNull) + ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false") } } @@ -592,7 +597,8 @@ object ExternalMapToCatalyst { keyType: DataType, keyConverter: Expression => Expression, valueType: DataType, - valueConverter: Expression => Expression): ExternalMapToCatalyst = { + valueConverter: Expression => Expression, + valueNullable: Boolean): ExternalMapToCatalyst = { val id = curId.getAndIncrement() val keyName = "ExternalMapToCatalyst_key" + id val valueName = "ExternalMapToCatalyst_value" + id @@ -601,11 +607,11 @@ object ExternalMapToCatalyst { ExternalMapToCatalyst( keyName, keyType, - keyConverter(LambdaVariable(keyName, "false", keyType)), + keyConverter(LambdaVariable(keyName, "false", keyType, false)), valueName, valueIsNull, valueType, - valueConverter(LambdaVariable(valueName, valueIsNull, valueType)), + valueConverter(LambdaVariable(valueName, valueIsNull, valueType, valueNullable)), inputMap ) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 4d896c2e38..080f11b769 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -24,7 +24,7 @@ import java.util.Arrays import scala.collection.mutable.ArrayBuffer import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.sql.Encoders +import org.apache.spark.sql.{Encoder, Encoders} import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} import org.apache.spark.sql.catalyst.analysis.AnalysisTest import org.apache.spark.sql.catalyst.dsl.plans._ @@ -300,6 +300,11 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { encodeDecodeTest( ReferenceValueClass(ReferenceValueClass.Container(1)), "reference value class") + encodeDecodeTest(Option(31), "option of int") + encodeDecodeTest(Option.empty[Int], "empty option of int") + encodeDecodeTest(Option("abc"), "option of string") + encodeDecodeTest(Option.empty[String], "empty option of string") + productTest(("UDT", new ExamplePoint(0.1, 0.2))) test("nullable of encoder schema") { @@ -338,6 +343,18 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { } } + test("nullable of encoder serializer") { + def checkNullable[T: Encoder](nullable: Boolean): Unit = { + assert(encoderFor[T].serializer.forall(_.nullable === nullable)) + } + + // test for flat encoders + checkNullable[Int](false) + checkNullable[Option[Int]](true) + checkNullable[java.lang.Integer](true) + checkNullable[String](true) + } + test("null check for map key") { val encoder = ExpressionEncoder[Map[String, Int]]() val e = intercept[RuntimeException](encoder.toRow(Map(("a", 1), (null, 2)))) -- cgit v1.2.3