aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorWenchen Fan <wenchen@databricks.com>2015-12-30 10:56:08 -0800
committerMichael Armbrust <michael@databricks.com>2015-12-30 10:56:08 -0800
commitaa48164a43bd9ed9eab53fcacbed92819e84eaf7 (patch)
tree195045d145396f3ae82a0031fc0c6a5919ef755d
parent932cf44248e067ee7cae6fef79ddf2ab9b1c36d8 (diff)
downloadspark-aa48164a43bd9ed9eab53fcacbed92819e84eaf7.tar.gz
spark-aa48164a43bd9ed9eab53fcacbed92819e84eaf7.tar.bz2
spark-aa48164a43bd9ed9eab53fcacbed92819e84eaf7.zip
[SPARK-12495][SQL] use true as default value for propagateNull in NewInstance
In most cases we should propagate null when calling `NewInstance`, and so far there is only one case where we should stop null propagation: creating a product/java bean. So I think it makes more sense to propagate null by default. This also fixes a bug when encoding a null array/map, which was first discovered in https://github.com/apache/spark/pull/10401 Author: Wenchen Fan <wenchen@databricks.com> Closes #10443 from cloud-fan/encoder.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala16
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala16
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala12
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala24
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala3
7 files changed, 38 insertions, 37 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index a1500cbc30..ed153d1f88 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -178,19 +178,19 @@ object JavaTypeInference {
case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath
case c if c == classOf[java.lang.Short] =>
- NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+ NewInstance(c, getPath :: Nil, ObjectType(c))
case c if c == classOf[java.lang.Integer] =>
- NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+ NewInstance(c, getPath :: Nil, ObjectType(c))
case c if c == classOf[java.lang.Long] =>
- NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+ NewInstance(c, getPath :: Nil, ObjectType(c))
case c if c == classOf[java.lang.Double] =>
- NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+ NewInstance(c, getPath :: Nil, ObjectType(c))
case c if c == classOf[java.lang.Byte] =>
- NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+ NewInstance(c, getPath :: Nil, ObjectType(c))
case c if c == classOf[java.lang.Float] =>
- NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+ NewInstance(c, getPath :: Nil, ObjectType(c))
case c if c == classOf[java.lang.Boolean] =>
- NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+ NewInstance(c, getPath :: Nil, ObjectType(c))
case c if c == classOf[java.sql.Date] =>
StaticInvoke(
@@ -298,7 +298,7 @@ object JavaTypeInference {
p.getWriteMethod.getName -> setter
}.toMap
- val newInstance = NewInstance(other, Nil, propagateNull = false, ObjectType(other))
+ val newInstance = NewInstance(other, Nil, ObjectType(other), propagateNull = false)
val result = InitializeJavaBean(newInstance, setters)
if (path.nonEmpty) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 8a22b37d07..9784c96966 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -189,37 +189,37 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[java.lang.Integer] =>
val boxedType = classOf[java.lang.Integer]
val objectType = ObjectType(boxedType)
- NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+ NewInstance(boxedType, getPath :: Nil, objectType)
case t if t <:< localTypeOf[java.lang.Long] =>
val boxedType = classOf[java.lang.Long]
val objectType = ObjectType(boxedType)
- NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+ NewInstance(boxedType, getPath :: Nil, objectType)
case t if t <:< localTypeOf[java.lang.Double] =>
val boxedType = classOf[java.lang.Double]
val objectType = ObjectType(boxedType)
- NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+ NewInstance(boxedType, getPath :: Nil, objectType)
case t if t <:< localTypeOf[java.lang.Float] =>
val boxedType = classOf[java.lang.Float]
val objectType = ObjectType(boxedType)
- NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+ NewInstance(boxedType, getPath :: Nil, objectType)
case t if t <:< localTypeOf[java.lang.Short] =>
val boxedType = classOf[java.lang.Short]
val objectType = ObjectType(boxedType)
- NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+ NewInstance(boxedType, getPath :: Nil, objectType)
case t if t <:< localTypeOf[java.lang.Byte] =>
val boxedType = classOf[java.lang.Byte]
val objectType = ObjectType(boxedType)
- NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+ NewInstance(boxedType, getPath :: Nil, objectType)
case t if t <:< localTypeOf[java.lang.Boolean] =>
val boxedType = classOf[java.lang.Boolean]
val objectType = ObjectType(boxedType)
- NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+ NewInstance(boxedType, getPath :: Nil, objectType)
case t if t <:< localTypeOf[java.sql.Date] =>
StaticInvoke(
@@ -349,7 +349,7 @@ object ScalaReflection extends ScalaReflection {
}
}
- val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls))
+ val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)
if (path.nonEmpty) {
expressions.If(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 7a4401cf58..ad4beda9c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -133,7 +133,7 @@ object ExpressionEncoder {
}
val fromRowExpression =
- NewInstance(cls, fromRowExpressions, propagateNull = false, ObjectType(cls))
+ NewInstance(cls, fromRowExpressions, ObjectType(cls), propagateNull = false)
new ExpressionEncoder[Any](
schema,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 63bdf05ca7..6f3d5ba84c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -55,7 +55,6 @@ object RowEncoder {
val obj = NewInstance(
udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
Nil,
- false,
dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)
@@ -166,7 +165,6 @@ object RowEncoder {
val obj = NewInstance(
udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
Nil,
- false,
dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index d40cd96905..fb404c12d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -165,7 +165,7 @@ case class Invoke(
${obj.code}
${argGen.map(_.code).mkString("\n")}
- boolean ${ev.isNull} = ${obj.value} == null;
+ boolean ${ev.isNull} = ${obj.isNull};
$javaType ${ev.value} =
${ev.isNull} ?
${ctx.defaultValue(dataType)} : ($javaType) $value;
@@ -178,8 +178,8 @@ object NewInstance {
def apply(
cls: Class[_],
arguments: Seq[Expression],
- propagateNull: Boolean = false,
- dataType: DataType): NewInstance =
+ dataType: DataType,
+ propagateNull: Boolean = true): NewInstance =
new NewInstance(cls, arguments, propagateNull, dataType, None)
}
@@ -231,7 +231,7 @@ case class NewInstance(
s"new $className($argString)"
}
- if (propagateNull) {
+ if (propagateNull && argGen.nonEmpty) {
val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})"
s"""
@@ -248,8 +248,8 @@ case class NewInstance(
s"""
$setup
- $javaType ${ev.value} = $constructorCall;
- final boolean ${ev.isNull} = ${ev.value} == null;
+ final $javaType ${ev.value} = $constructorCall;
+ final boolean ${ev.isNull} = false;
"""
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index 764ffdc094..bc36a55ae0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -46,8 +46,8 @@ class EncoderResolutionSuite extends PlanTest {
toExternalString('a.string),
AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long")
),
- false,
- ObjectType(cls))
+ ObjectType(cls),
+ propagateNull = false)
compareExpressions(fromRowExpr, expected)
}
@@ -60,8 +60,8 @@ class EncoderResolutionSuite extends PlanTest {
toExternalString('a.int.cast(StringType)),
AssertNotNull('b.long, cls.getName, "b", "Long")
),
- false,
- ObjectType(cls))
+ ObjectType(cls),
+ propagateNull = false)
compareExpressions(fromRowExpr, expected)
}
}
@@ -88,11 +88,11 @@ class EncoderResolutionSuite extends PlanTest {
AssertNotNull(
GetStructField('b.struct('a.int, 'b.long), 1, Some("b")),
innerCls.getName, "b", "Long")),
- false,
- ObjectType(innerCls))
+ ObjectType(innerCls),
+ propagateNull = false)
)),
- false,
- ObjectType(cls))
+ ObjectType(cls),
+ propagateNull = false)
compareExpressions(fromRowExpr, expected)
}
@@ -114,11 +114,11 @@ class EncoderResolutionSuite extends PlanTest {
AssertNotNull(
GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType),
cls.getName, "b", "Long")),
- false,
- ObjectType(cls)),
+ ObjectType(cls),
+ propagateNull = false),
'b.int.cast(LongType)),
- false,
- ObjectType(classOf[Tuple2[_, _]]))
+ ObjectType(classOf[Tuple2[_, _]]),
+ propagateNull = false)
compareExpressions(fromRowExpr, expected)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 7233e0f1b5..666699e18d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -128,6 +128,9 @@ class ExpressionEncoderSuite extends SparkFunSuite {
encodeDecodeTest(Map(1 -> "a", 2 -> null), "map with null")
encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)), "map of map")
+ encodeDecodeTest(Tuple1[Seq[Int]](null), "null seq in tuple")
+ encodeDecodeTest(Tuple1[Map[String, String]](null), "null map in tuple")
+
// Kryo encoders
encodeDecodeTest("hello", "kryo string")(encoderFor(Encoders.kryo[String]))
encodeDecodeTest(new KryoSerializable(15), "kryo object")(