aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorWenchen Fan <wenchen@databricks.com>2016-02-08 12:06:00 -0800
committerMichael Armbrust <michael@databricks.com>2016-02-08 12:06:00 -0800
commit8e4d15f70713e1aaaa96dfb3ea4ccc5bb08eb2ce (patch)
treed03e7c60c6d08a606331bafa139dcd7cfa443e1b /sql
parent06f0df6df204c4722ff8a6bf909abaa32a715c41 (diff)
downloadspark-8e4d15f70713e1aaaa96dfb3ea4ccc5bb08eb2ce.tar.gz
spark-8e4d15f70713e1aaaa96dfb3ea4ccc5bb08eb2ce.tar.bz2
spark-8e4d15f70713e1aaaa96dfb3ea4ccc5bb08eb2ce.zip
[SPARK-13101][SQL] nullability of array type element should not fail analysis of encoder
Nullability should be treated only as an optimization rather than as part of the type system, so instead of failing analysis on a nullability mismatch, we should pass analysis and add a runtime null check. Author: Wenchen Fan <wenchen@databricks.com> Closes #11035 from cloud-fan/ignore-nullability.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala29
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala20
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala107
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java4
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala4
7 files changed, 64 insertions, 104 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 3c3717d504..59ee41d02f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -292,7 +292,7 @@ object JavaTypeInference {
val setter = if (nullable) {
constructor
} else {
- AssertNotNull(constructor, other.getName, fieldName, fieldType.toString)
+ AssertNotNull(constructor, Seq("currently no type path record in java"))
}
p.getWriteMethod.getName -> setter
}.toMap
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index e5811efb43..02cb2d9a2b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -249,6 +249,8 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[Array[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
+
+ // TODO: add runtime null check for primitive array
val primitiveMethod = elementType match {
case t if t <:< definitions.IntTpe => Some("toIntArray")
case t if t <:< definitions.LongTpe => Some("toLongArray")
@@ -276,22 +278,29 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
+ val Schema(dataType, nullable) = schemaFor(elementType)
val className = getClassNameFromType(elementType)
val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
- val arrayData =
- Invoke(
- MapObjects(
- p => constructorFor(elementType, Some(p), newTypePath),
- getPath,
- schemaFor(elementType).dataType),
- "array",
- ObjectType(classOf[Array[Any]]))
+
+ val mapFunction: Expression => Expression = p => {
+ val converter = constructorFor(elementType, Some(p), newTypePath)
+ if (nullable) {
+ converter
+ } else {
+ AssertNotNull(converter, newTypePath)
+ }
+ }
+
+ val array = Invoke(
+ MapObjects(mapFunction, getPath, dataType),
+ "array",
+ ObjectType(classOf[Array[Any]]))
StaticInvoke(
scala.collection.mutable.WrappedArray.getClass,
ObjectType(classOf[Seq[_]]),
"make",
- arrayData :: Nil)
+ array :: Nil)
case t if t <:< localTypeOf[Map[_, _]] =>
// TODO: add walked type path for map
@@ -343,7 +352,7 @@ object ScalaReflection extends ScalaReflection {
newTypePath)
if (!nullable) {
- AssertNotNull(constructor, t.toString, fieldName, fieldType.toString)
+ AssertNotNull(constructor, newTypePath)
} else {
constructor
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index cb228cf52b..4d53b232d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1426,7 +1426,7 @@ object ResolveUpCast extends Rule[LogicalPlan] {
fail(child, DateType, walkedTypePath)
case (StringType, to: NumericType) =>
fail(child, to, walkedTypePath)
- case _ => Cast(child, dataType)
+ case _ => Cast(child, dataType.asNullable)
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 79fe0033b7..fef6825b2d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -365,7 +365,7 @@ object MapObjects {
* to handle collection elements.
* @param inputData An expression that when evaluted returns a collection object.
*/
-case class MapObjects(
+case class MapObjects private(
loopVar: LambdaVariable,
lambdaFunction: Expression,
inputData: Expression) extends Expression {
@@ -637,8 +637,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
* `Int` field named `i`. Expression `s.i` is nullable because `s` can be null. However, for all
* non-null `s`, `s.i` can't be null.
*/
-case class AssertNotNull(
- child: Expression, parentType: String, fieldName: String, fieldType: String)
+case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
extends UnaryExpression {
override def dataType: DataType = child.dataType
@@ -651,6 +650,14 @@ case class AssertNotNull(
override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = {
val childGen = child.gen(ctx)
+ val errMsg = "Null value appeared in non-nullable field:" +
+ walkedTypePath.mkString("\n", "\n", "\n") +
+ "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
+ "please try to use scala.Option[_] or other nullable types " +
+ "(e.g. java.lang.Integer instead of int/scala.Int)."
+ val idx = ctx.references.length
+ ctx.references += errMsg
+
ev.isNull = "false"
ev.value = childGen.value
@@ -658,12 +665,7 @@ case class AssertNotNull(
${childGen.code}
if (${childGen.isNull}) {
- throw new RuntimeException(
- "Null value appeared in non-nullable field $parentType.$fieldName of type $fieldType. " +
- "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
- "please try to use scala.Option[_] or other nullable types " +
- "(e.g. java.lang.Integer instead of int/scala.Int)."
- );
+ throw new RuntimeException((String) references[$idx]);
}
"""
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index 92a68a4dba..8b02b63c6c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -21,9 +21,11 @@ import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
case class StringLongClass(a: String, b: Long)
@@ -32,94 +34,49 @@ case class StringIntClass(a: String, b: Int)
case class ComplexClass(a: Long, b: StringLongClass)
class EncoderResolutionSuite extends PlanTest {
+ private val str = UTF8String.fromString("hello")
+
test("real type doesn't match encoder schema but they are compatible: product") {
val encoder = ExpressionEncoder[StringLongClass]
- val cls = classOf[StringLongClass]
-
- {
- val attrs = Seq('a.string, 'b.int)
- val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
- val expected: Expression = NewInstance(
- cls,
- Seq(
- toExternalString('a.string),
- AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long")
- ),
- ObjectType(cls),
- propagateNull = false)
- compareExpressions(fromRowExpr, expected)
- }
+ // int type can be up cast to long type
+ val attrs1 = Seq('a.string, 'b.int)
+ encoder.resolve(attrs1, null).bind(attrs1).fromRow(InternalRow(str, 1))
- {
- val attrs = Seq('a.int, 'b.long)
- val fromRowExpr = encoder.resolve(attrs, null).fromRowExpression
- val expected = NewInstance(
- cls,
- Seq(
- toExternalString('a.int.cast(StringType)),
- AssertNotNull('b.long, cls.getName, "b", "Long")
- ),
- ObjectType(cls),
- propagateNull = false)
- compareExpressions(fromRowExpr, expected)
- }
+ // int type can be up cast to string type
+ val attrs2 = Seq('a.int, 'b.long)
+ encoder.resolve(attrs2, null).bind(attrs2).fromRow(InternalRow(1, 2L))
}
test("real type doesn't match encoder schema but they are compatible: nested product") {
val encoder = ExpressionEncoder[ComplexClass]
- val innerCls = classOf[StringLongClass]
- val cls = classOf[ComplexClass]
-
val attrs = Seq('a.int, 'b.struct('a.int, 'b.long))
- val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
- val expected: Expression = NewInstance(
- cls,
- Seq(
- AssertNotNull('a.int.cast(LongType), cls.getName, "a", "Long"),
- If(
- 'b.struct('a.int, 'b.long).isNull,
- Literal.create(null, ObjectType(innerCls)),
- NewInstance(
- innerCls,
- Seq(
- toExternalString(
- GetStructField('b.struct('a.int, 'b.long), 0, Some("a")).cast(StringType)),
- AssertNotNull(
- GetStructField('b.struct('a.int, 'b.long), 1, Some("b")),
- innerCls.getName, "b", "Long")),
- ObjectType(innerCls),
- propagateNull = false)
- )),
- ObjectType(cls),
- propagateNull = false)
- compareExpressions(fromRowExpr, expected)
+ encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L)))
}
test("real type doesn't match encoder schema but they are compatible: tupled encoder") {
val encoder = ExpressionEncoder.tuple(
ExpressionEncoder[StringLongClass],
ExpressionEncoder[Long])
- val cls = classOf[StringLongClass]
-
val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int)
- val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
- val expected: Expression = NewInstance(
- classOf[Tuple2[_, _]],
- Seq(
- NewInstance(
- cls,
- Seq(
- toExternalString(GetStructField('a.struct('a.string, 'b.byte), 0, Some("a"))),
- AssertNotNull(
- GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType),
- cls.getName, "b", "Long")),
- ObjectType(cls),
- propagateNull = false),
- 'b.int.cast(LongType)),
- ObjectType(classOf[Tuple2[_, _]]),
- propagateNull = false)
- compareExpressions(fromRowExpr, expected)
+ encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2))
+ }
+
+ test("nullability of array type element should not fail analysis") {
+ val encoder = ExpressionEncoder[Seq[Int]]
+ val attrs = 'a.array(IntegerType) :: Nil
+
+ // It should pass analysis
+ val bound = encoder.resolve(attrs, null).bind(attrs)
+
+ // If no null values appear, it should work fine
+ bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2))))
+
+ // If there is a null value, it should throw a runtime exception
+ val e = intercept[RuntimeException] {
+ bound.fromRow(InternalRow(new GenericArrayData(Array(1, null))))
+ }
+ assert(e.getMessage.contains("Null value appeared in non-nullable field"))
}
test("the real number of fields doesn't match encoder schema: tuple encoder") {
@@ -166,10 +123,6 @@ class EncoderResolutionSuite extends PlanTest {
}
}
- private def toExternalString(e: Expression): Expression = {
- Invoke(e, "toString", ObjectType(classOf[String]), Nil)
- }
-
test("throw exception if real type is not compatible with encoder schema") {
val msg1 = intercept[AnalysisException] {
ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null)
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index a6fb62c17d..1181244c8a 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -850,9 +850,7 @@ public class JavaDatasetSuite implements Serializable {
}
nullabilityCheck.expect(RuntimeException.class);
- nullabilityCheck.expectMessage(
- "Null value appeared in non-nullable field " +
- "test.org.apache.spark.sql.JavaDatasetSuite$SmallBean.b of type int.");
+ nullabilityCheck.expectMessage("Null value appeared in non-nullable field");
{
Row row = new GenericRow(new Object[] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 374f4320a9..f9ba607700 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -553,9 +553,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
buildDataset(Row(Row("hello", null))).collect()
}.getMessage
- assert(message.contains(
- "Null value appeared in non-nullable field org.apache.spark.sql.ClassData.b of type Int."
- ))
+ assert(message.contains("Null value appeared in non-nullable field"))
}
test("SPARK-12478: top level null field") {