 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala               |   2
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala                 |  29
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala               |   2
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala             |  20
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala | 107
 sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java                          |   4
 sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala                                 |   4
 7 files changed, 64 insertions(+), 104 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 3c3717d504..59ee41d02f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -292,7 +292,7 @@ object JavaTypeInference {
val setter = if (nullable) {
constructor
} else {
- AssertNotNull(constructor, other.getName, fieldName, fieldType.toString)
+ AssertNotNull(constructor, Seq("currently no type path record in java"))
}
p.getWriteMethod.getName -> setter
}.toMap
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index e5811efb43..02cb2d9a2b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -249,6 +249,8 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[Array[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
+
+ // TODO: add runtime null check for primitive array
val primitiveMethod = elementType match {
case t if t <:< definitions.IntTpe => Some("toIntArray")
case t if t <:< definitions.LongTpe => Some("toLongArray")
@@ -276,22 +278,29 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
+ val Schema(dataType, nullable) = schemaFor(elementType)
val className = getClassNameFromType(elementType)
val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
- val arrayData =
- Invoke(
- MapObjects(
- p => constructorFor(elementType, Some(p), newTypePath),
- getPath,
- schemaFor(elementType).dataType),
- "array",
- ObjectType(classOf[Array[Any]]))
+
+ val mapFunction: Expression => Expression = p => {
+ val converter = constructorFor(elementType, Some(p), newTypePath)
+ if (nullable) {
+ converter
+ } else {
+ AssertNotNull(converter, newTypePath)
+ }
+ }
+
+ val array = Invoke(
+ MapObjects(mapFunction, getPath, dataType),
+ "array",
+ ObjectType(classOf[Array[Any]]))
StaticInvoke(
scala.collection.mutable.WrappedArray.getClass,
ObjectType(classOf[Seq[_]]),
"make",
- arrayData :: Nil)
+ array :: Nil)
case t if t <:< localTypeOf[Map[_, _]] =>
// TODO: add walked type path for map
@@ -343,7 +352,7 @@ object ScalaReflection extends ScalaReflection {
newTypePath)
if (!nullable) {
- AssertNotNull(constructor, t.toString, fieldName, fieldType.toString)
+ AssertNotNull(constructor, newTypePath)
} else {
constructor
}
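
Note: the effect of the Seq[_] change above is that every array element converter is wrapped in AssertNotNull whenever the element type is non-nullable (e.g. Seq[Int]), so a null element fails fast with the walked type path instead of surfacing later as an opaque NullPointerException or a silent default value. A minimal, Spark-free sketch of the fail-fast idea (the names here are illustrative, not Spark's):

    object AssertNotNullSketch {
      // Mirrors the message construction in AssertNotNull.genCode below.
      def assertNotNull[T](value: T, walkedTypePath: Seq[String]): T = {
        if (value == null) {
          throw new RuntimeException(
            "Null value appeared in non-nullable field:" +
              walkedTypePath.mkString("\n", "\n", "\n") +
              "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
              "please try to use scala.Option[_] or other nullable types " +
              "(e.g. java.lang.Integer instead of int/scala.Int).")
        }
        value
      }

      def main(args: Array[String]): Unit = {
        val path = Seq("- array element class: \"scala.Int\"")
        // A decoded array with a null element, as in the new test below.
        val elements: Array[Any] = Array(1, null)
        elements.map(e => assertNotNull(e, path)) // throws RuntimeException
      }
    }
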
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index cb228cf52b..4d53b232d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1426,7 +1426,7 @@ object ResolveUpCast extends Rule[LogicalPlan] {
fail(child, DateType, walkedTypePath)
case (StringType, to: NumericType) =>
fail(child, to, walkedTypePath)
- case _ => Cast(child, dataType)
+ case _ => Cast(child, dataType.asNullable)
}
}
}
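
Note: ResolveUpCast now casts to dataType.asNullable, so a mismatch in nullability alone (for example an input of array<int> with containsNull = true being read as Seq[Int], whose inferred schema says containsNull = false) no longer fails analysis; nullability is enforced at runtime by AssertNotNull instead. asNullable is package-private to Spark; conceptually it relaxes nullability recursively, roughly like this sketch (UDTs and other cases omitted):

    import org.apache.spark.sql.types._

    object NullabilityRelaxSketch {
      def relax(dt: DataType): DataType = dt match {
        case ArrayType(et, _)   => ArrayType(relax(et), containsNull = true)
        case MapType(kt, vt, _) => MapType(relax(kt), relax(vt), valueContainsNull = true)
        case StructType(fields) =>
          StructType(fields.map(f => f.copy(dataType = relax(f.dataType), nullable = true)))
        case other => other
      }

      def main(args: Array[String]): Unit =
        // prints ArrayType(IntegerType,true)
        println(relax(ArrayType(IntegerType, containsNull = false)))
    }
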
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 79fe0033b7..fef6825b2d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -365,7 +365,7 @@ object MapObjects {
* to handle collection elements.
 * @param inputData An expression that when evaluated returns a collection object.
*/
-case class MapObjects(
+case class MapObjects private(
loopVar: LambdaVariable,
lambdaFunction: Expression,
inputData: Expression) extends Expression {
@@ -637,8 +637,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
* `Int` field named `i`. Expression `s.i` is nullable because `s` can be null. However, for all
* non-null `s`, `s.i` can't be null.
*/
-case class AssertNotNull(
- child: Expression, parentType: String, fieldName: String, fieldType: String)
+case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
extends UnaryExpression {
override def dataType: DataType = child.dataType
@@ -651,6 +650,14 @@ case class AssertNotNull(
override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = {
val childGen = child.gen(ctx)
+ val errMsg = "Null value appeared in non-nullable field:" +
+ walkedTypePath.mkString("\n", "\n", "\n") +
+ "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
+ "please try to use scala.Option[_] or other nullable types " +
+ "(e.g. java.lang.Integer instead of int/scala.Int)."
+ val idx = ctx.references.length
+ ctx.references += errMsg
+
ev.isNull = "false"
ev.value = childGen.value
@@ -658,12 +665,7 @@ case class AssertNotNull(
${childGen.code}
if (${childGen.isNull}) {
- throw new RuntimeException(
- "Null value appeared in non-nullable field $parentType.$fieldName of type $fieldType. " +
- "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
- "please try to use scala.Option[_] or other nullable types " +
- "(e.g. java.lang.Integer instead of int/scala.Int)."
- );
+ throw new RuntimeException((String) references[$idx]);
}
"""
}
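
Note: the generated null check no longer inlines the error message as a Java string literal. The message now embeds the multi-line walked type path, including embedded quotes, which is fragile to splice into generated Java source; instead it is appended to the codegen context's references array and looked up by index at runtime. A self-contained sketch of that idiom (hypothetical names, not Spark's actual CodegenContext):

    import scala.collection.mutable.ArrayBuffer

    object ReferencesSketch {
      // Objects handed to the generated class alongside its compiled code.
      val references = ArrayBuffer.empty[Any]

      def genNullCheck(childIsNull: String): String = {
        val errMsg = "Null value appeared in non-nullable field: ..."
        val idx = references.length
        references += errMsg // no escaping needed: the string never enters Java source
        s"""
           |if ($childIsNull) {
           |  throw new RuntimeException((String) references[$idx]);
           |}
         """.stripMargin
      }

      def main(args: Array[String]): Unit =
        println(genNullCheck("isNull_0"))
    }
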
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index 92a68a4dba..8b02b63c6c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -21,9 +21,11 @@ import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
case class StringLongClass(a: String, b: Long)
@@ -32,94 +34,49 @@ case class StringIntClass(a: String, b: Int)
case class ComplexClass(a: Long, b: StringLongClass)
class EncoderResolutionSuite extends PlanTest {
+ private val str = UTF8String.fromString("hello")
+
test("real type doesn't match encoder schema but they are compatible: product") {
val encoder = ExpressionEncoder[StringLongClass]
- val cls = classOf[StringLongClass]
-
- {
- val attrs = Seq('a.string, 'b.int)
- val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
- val expected: Expression = NewInstance(
- cls,
- Seq(
- toExternalString('a.string),
- AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long")
- ),
- ObjectType(cls),
- propagateNull = false)
- compareExpressions(fromRowExpr, expected)
- }
+ // int type can be up cast to long type
+ val attrs1 = Seq('a.string, 'b.int)
+ encoder.resolve(attrs1, null).bind(attrs1).fromRow(InternalRow(str, 1))
- {
- val attrs = Seq('a.int, 'b.long)
- val fromRowExpr = encoder.resolve(attrs, null).fromRowExpression
- val expected = NewInstance(
- cls,
- Seq(
- toExternalString('a.int.cast(StringType)),
- AssertNotNull('b.long, cls.getName, "b", "Long")
- ),
- ObjectType(cls),
- propagateNull = false)
- compareExpressions(fromRowExpr, expected)
- }
+ // int type can be up cast to string type
+ val attrs2 = Seq('a.int, 'b.long)
+ encoder.resolve(attrs2, null).bind(attrs2).fromRow(InternalRow(1, 2L))
}
test("real type doesn't match encoder schema but they are compatible: nested product") {
val encoder = ExpressionEncoder[ComplexClass]
- val innerCls = classOf[StringLongClass]
- val cls = classOf[ComplexClass]
-
val attrs = Seq('a.int, 'b.struct('a.int, 'b.long))
- val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
- val expected: Expression = NewInstance(
- cls,
- Seq(
- AssertNotNull('a.int.cast(LongType), cls.getName, "a", "Long"),
- If(
- 'b.struct('a.int, 'b.long).isNull,
- Literal.create(null, ObjectType(innerCls)),
- NewInstance(
- innerCls,
- Seq(
- toExternalString(
- GetStructField('b.struct('a.int, 'b.long), 0, Some("a")).cast(StringType)),
- AssertNotNull(
- GetStructField('b.struct('a.int, 'b.long), 1, Some("b")),
- innerCls.getName, "b", "Long")),
- ObjectType(innerCls),
- propagateNull = false)
- )),
- ObjectType(cls),
- propagateNull = false)
- compareExpressions(fromRowExpr, expected)
+ encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L)))
}
test("real type doesn't match encoder schema but they are compatible: tupled encoder") {
val encoder = ExpressionEncoder.tuple(
ExpressionEncoder[StringLongClass],
ExpressionEncoder[Long])
- val cls = classOf[StringLongClass]
-
val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int)
- val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
- val expected: Expression = NewInstance(
- classOf[Tuple2[_, _]],
- Seq(
- NewInstance(
- cls,
- Seq(
- toExternalString(GetStructField('a.struct('a.string, 'b.byte), 0, Some("a"))),
- AssertNotNull(
- GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType),
- cls.getName, "b", "Long")),
- ObjectType(cls),
- propagateNull = false),
- 'b.int.cast(LongType)),
- ObjectType(classOf[Tuple2[_, _]]),
- propagateNull = false)
- compareExpressions(fromRowExpr, expected)
+ encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2))
+ }
+
+ test("nullability of array type element should not fail analysis") {
+ val encoder = ExpressionEncoder[Seq[Int]]
+ val attrs = 'a.array(IntegerType) :: Nil
+
+ // It should pass analysis
+ val bound = encoder.resolve(attrs, null).bind(attrs)
+
+ // If no null values appear, it should work fine
+ bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2))))
+
+ // If there is a null value, it should throw a runtime exception
+ val e = intercept[RuntimeException] {
+ bound.fromRow(InternalRow(new GenericArrayData(Array(1, null))))
+ }
+ assert(e.getMessage.contains("Null value appeared in non-nullable field"))
}
test("the real number of fields doesn't match encoder schema: tuple encoder") {
@@ -166,10 +123,6 @@ class EncoderResolutionSuite extends PlanTest {
}
}
- private def toExternalString(e: Expression): Expression = {
- Invoke(e, "toString", ObjectType(classOf[String]), Nil)
- }
-
test("throw exception if real type is not compatible with encoder schema") {
val msg1 = intercept[AnalysisException] {
ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null)
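
Note: these tests were rewritten from brittle expression-tree comparisons (hence the removed toExternalString helper) to end-to-end evaluation: resolve the encoder against the input attributes, bind the attribute ordinals, and run fromRow on a concrete InternalRow. The pattern, as a sketch of the suite body reusing its StringLongClass:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    import org.apache.spark.unsafe.types.UTF8String

    val encoder = ExpressionEncoder[StringLongClass]
    val attrs = Seq('a.string, 'b.long)
    // resolve against the input schema, bind ordinals, then evaluate
    val obj: StringLongClass = encoder
      .resolve(attrs, null)
      .bind(attrs)
      .fromRow(InternalRow(UTF8String.fromString("hello"), 1L))
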
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index a6fb62c17d..1181244c8a 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -850,9 +850,7 @@ public class JavaDatasetSuite implements Serializable {
}
nullabilityCheck.expect(RuntimeException.class);
- nullabilityCheck.expectMessage(
- "Null value appeared in non-nullable field " +
- "test.org.apache.spark.sql.JavaDatasetSuite$SmallBean.b of type int.");
+ nullabilityCheck.expectMessage("Null value appeared in non-nullable field");
{
Row row = new GenericRow(new Object[] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 374f4320a9..f9ba607700 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -553,9 +553,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
buildDataset(Row(Row("hello", null))).collect()
}.getMessage
- assert(message.contains(
- "Null value appeared in non-nullable field org.apache.spark.sql.ClassData.b of type Int."
- ))
+ assert(message.contains("Null value appeared in non-nullable field"))
}
test("SPARK-12478: top level null field") {