aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src
diff options
context:
space:
mode:
authorKazuaki Ishizaki <ishizaki@jp.ibm.com>2016-12-02 12:30:13 +0800
committerWenchen Fan <wenchen@databricks.com>2016-12-02 12:30:13 +0800
commit38b9e69623c14a675b14639e8291f5d29d2a0bc3 (patch)
tree7dabeeb22f97923554f9fa155c6e7e22733ad060 /sql/catalyst/src
parent70c5549ee9588228d18a7b405c977cf591e2efd4 (diff)
downloadspark-38b9e69623c14a675b14639e8291f5d29d2a0bc3.tar.gz
spark-38b9e69623c14a675b14639e8291f5d29d2a0bc3.tar.bz2
spark-38b9e69623c14a675b14639e8291f5d29d2a0bc3.zip
[SPARK-18284][SQL] Make ExpressionEncoder.serializer.nullable precise
## What changes were proposed in this pull request? This PR makes `ExpressionEncoder.serializer.nullable` for a flat encoder of a primitive type `false`. Since it is `true` for now, it is too conservative. While `ExpressionEncoder.schema` has correct information (e.g. `<IntegerType, false>`), `serializer.head.nullable` of `ExpressionEncoder`, which is obtained from `encoderFor[T]`, is always `true`. It is too conservative. This is accomplished by checking whether a type is one of the primitive types; if it is, `nullable` is set to `false`. ## How was this patch tested? Added new tests for the encoder and DataFrame. Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com> Closes #15780 from kiszk/SPARK-18284.
Diffstat (limited to 'sql/catalyst/src')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala7
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala7
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala24
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala19
6 files changed, 44 insertions, 19 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 04f0cfce88..7e8e4dab72 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -396,12 +396,14 @@ object JavaTypeInference {
case _ if mapType.isAssignableFrom(typeToken) =>
val (keyType, valueType) = mapKeyValueType(typeToken)
+
ExternalMapToCatalyst(
inputObject,
ObjectType(keyType.getRawType),
serializerFor(_, keyType),
ObjectType(valueType.getRawType),
- serializerFor(_, valueType)
+ serializerFor(_, valueType),
+ valueNullable = true
)
case other =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 0aa21b9347..6e20096901 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -498,7 +498,8 @@ object ScalaReflection extends ScalaReflection {
dataTypeFor(keyType),
serializerFor(_, keyType, keyPath),
dataTypeFor(valueType),
- serializerFor(_, valueType, valuePath))
+ serializerFor(_, valueType, valuePath),
+ valueNullable = !valueType.typeSymbol.asClass.isPrimitive)
case t if t <:< localTypeOf[String] =>
StaticInvoke(
@@ -590,7 +591,9 @@ object ScalaReflection extends ScalaReflection {
"cannot be used as field name\n" + walkedTypePath.mkString("\n"))
}
- val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
+ val fieldValue = Invoke(
+ AssertNotNull(inputObject, walkedTypePath), fieldName, dataTypeFor(fieldType),
+ returnNullable = !fieldType.typeSymbol.asClass.isPrimitive)
val clsName = getClassNameFromType(fieldType)
val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 9c4818db63..3757eccfa2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -60,7 +60,7 @@ object ExpressionEncoder {
val cls = mirror.runtimeClass(tpe)
val flat = !ScalaReflection.definedByConstructorParams(tpe)
- val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true)
+ val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = !cls.isPrimitive)
val nullSafeInput = if (flat) {
inputObject
} else {
@@ -71,10 +71,7 @@ object ExpressionEncoder {
val serializer = ScalaReflection.serializerFor[T](nullSafeInput)
val deserializer = ScalaReflection.deserializerFor[T]
- val schema = ScalaReflection.schemaFor[T] match {
- case ScalaReflection.Schema(s: StructType, _) => s
- case ScalaReflection.Schema(dt, nullable) => new StructType().add("value", dt, nullable)
- }
+ val schema = serializer.dataType
new ExpressionEncoder[T](
schema,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
index 6c75a7a502..2ca77e8394 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
@@ -74,7 +74,7 @@ case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
ctx.addMutableState("boolean", classChildVarIsNull, "")
val classChildVar =
- LambdaVariable(classChildVarName, classChildVarIsNull, child.dataType)
+ LambdaVariable(classChildVarName, classChildVarIsNull, child.dataType, child.nullable)
val initCode = s"${classChildVar.value} = ${childGen.value};\n" +
s"${classChildVar.isNull} = ${childGen.isNull};"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index e517ec18eb..a8aa1e7255 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -171,15 +171,18 @@ case class StaticInvoke(
* @param arguments An optional list of expressions, whos evaluation will be passed to the function.
* @param propagateNull When true, and any of the arguments is null, null will be returned instead
* of calling the function.
+ * @param returnNullable When false, indicating the invoked method will always return
+ * non-null value.
*/
case class Invoke(
targetObject: Expression,
functionName: String,
dataType: DataType,
arguments: Seq[Expression] = Nil,
- propagateNull: Boolean = true) extends InvokeLike {
+ propagateNull: Boolean = true,
+ returnNullable : Boolean = true) extends InvokeLike {
- override def nullable: Boolean = true
+ override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable
override def children: Seq[Expression] = targetObject +: arguments
override def eval(input: InternalRow): Any =
@@ -405,13 +408,15 @@ case class WrapOption(child: Expression, optType: DataType)
* A place holder for the loop variable used in [[MapObjects]]. This should never be constructed
* manually, but will instead be passed into the provided lambda function.
*/
-case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends LeafExpression
+case class LambdaVariable(
+ value: String,
+ isNull: String,
+ dataType: DataType,
+ nullable: Boolean = true) extends LeafExpression
with Unevaluable with NonSQLExpression {
- override def nullable: Boolean = true
-
override def genCode(ctx: CodegenContext): ExprCode = {
- ExprCode(code = "", value = value, isNull = isNull)
+ ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false")
}
}
@@ -592,7 +597,8 @@ object ExternalMapToCatalyst {
keyType: DataType,
keyConverter: Expression => Expression,
valueType: DataType,
- valueConverter: Expression => Expression): ExternalMapToCatalyst = {
+ valueConverter: Expression => Expression,
+ valueNullable: Boolean): ExternalMapToCatalyst = {
val id = curId.getAndIncrement()
val keyName = "ExternalMapToCatalyst_key" + id
val valueName = "ExternalMapToCatalyst_value" + id
@@ -601,11 +607,11 @@ object ExternalMapToCatalyst {
ExternalMapToCatalyst(
keyName,
keyType,
- keyConverter(LambdaVariable(keyName, "false", keyType)),
+ keyConverter(LambdaVariable(keyName, "false", keyType, false)),
valueName,
valueIsNull,
valueType,
- valueConverter(LambdaVariable(valueName, valueIsNull, valueType)),
+ valueConverter(LambdaVariable(valueName, valueIsNull, valueType, valueNullable)),
inputMap
)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 4d896c2e38..080f11b769 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -24,7 +24,7 @@ import java.util.Arrays
import scala.collection.mutable.ArrayBuffer
import scala.reflect.runtime.universe.TypeTag
-import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
import org.apache.spark.sql.catalyst.analysis.AnalysisTest
import org.apache.spark.sql.catalyst.dsl.plans._
@@ -300,6 +300,11 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
encodeDecodeTest(
ReferenceValueClass(ReferenceValueClass.Container(1)), "reference value class")
+ encodeDecodeTest(Option(31), "option of int")
+ encodeDecodeTest(Option.empty[Int], "empty option of int")
+ encodeDecodeTest(Option("abc"), "option of string")
+ encodeDecodeTest(Option.empty[String], "empty option of string")
+
productTest(("UDT", new ExamplePoint(0.1, 0.2)))
test("nullable of encoder schema") {
@@ -338,6 +343,18 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
}
}
+ test("nullable of encoder serializer") {
+ def checkNullable[T: Encoder](nullable: Boolean): Unit = {
+ assert(encoderFor[T].serializer.forall(_.nullable === nullable))
+ }
+
+ // test for flat encoders
+ checkNullable[Int](false)
+ checkNullable[Option[Int]](true)
+ checkNullable[java.lang.Integer](true)
+ checkNullable[String](true)
+ }
+
test("null check for map key") {
val encoder = ExpressionEncoder[Map[String, Int]]()
val e = intercept[RuntimeException](encoder.toRow(Map(("a", 1), (null, 2))))