 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala               |   9 ++++++++-
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala                 |  10 ++++++++--
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala             |  40 ++++++++++
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala |  21 ++++++---
 sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala                                      |   3 ++-
 sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java                          | 126 +++++++++++++-
 sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala                                 |  33 ++++++
 7 files changed, 232 insertions(+), 10 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index f566d1b3ca..a1500cbc30 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -288,7 +288,14 @@ object JavaTypeInference {
val setters = properties.map { p =>
val fieldName = p.getName
val fieldType = typeToken.method(p.getReadMethod).getReturnType
- p.getWriteMethod.getName -> constructorFor(fieldType, Some(addToPath(fieldName)))
+ val (_, nullable) = inferDataType(fieldType)
+ val constructor = constructorFor(fieldType, Some(addToPath(fieldName)))
+ val setter = if (nullable) {
+ constructor
+ } else {
+ AssertNotNull(constructor, other.getName, fieldName, fieldType.toString)
+ }
+ p.getWriteMethod.getName -> setter
}.toMap
val newInstance = NewInstance(other, Nil, propagateNull = false, ObjectType(other))
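
The rule this hunk applies, restated as a self-contained sketch: a bean property whose inferred Catalyst type is non-nullable (in practice, a property of Java primitive type such as `int`) gets its deserializer expression wrapped in a null assertion, while nullable properties pass through untouched. The `Expr` tree and helper names below are illustrative stand-ins, not Spark's API:

sealed trait Expr
case class FieldDeserializer(fieldName: String) extends Expr
case class AssertNotNullExpr(
    child: Expr, parentType: String, fieldName: String, fieldType: String) extends Expr

// Mirrors inferDataType: a bean property is nullable unless its type is primitive.
def inferNullable(fieldClass: Class[_]): Boolean = !fieldClass.isPrimitive

def setterExprFor(beanClass: Class[_], fieldName: String, fieldClass: Class[_]): Expr = {
  val deserializer = FieldDeserializer(fieldName)
  if (inferNullable(fieldClass)) {
    deserializer // null is a legal value here, nothing to assert
  } else {
    AssertNotNullExpr(deserializer, beanClass.getName, fieldName, fieldClass.getName)
  }
}

For example, `setterExprFor(beanClass, "b", java.lang.Integer.TYPE)` returns the wrapped expression, whereas a `String`-typed property comes back unwrapped.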
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index cc9e6af181..becd019cae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -326,7 +326,7 @@ object ScalaReflection extends ScalaReflection {
val cls = getClassFromType(tpe)
val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) =>
- val dataType = schemaFor(fieldType).dataType
+ val Schema(dataType, nullable) = schemaFor(fieldType)
val clsName = getClassNameFromType(fieldType)
val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
// For tuples, we grab the inner fields by ordinal instead of name.
@@ -336,10 +336,16 @@ object ScalaReflection extends ScalaReflection {
Some(addToPathOrdinal(i, dataType, newTypePath)),
newTypePath)
} else {
- constructorFor(
+ val constructor = constructorFor(
fieldType,
Some(addToPath(fieldName, dataType, newTypePath)),
newTypePath)
+
+ if (!nullable) {
+ AssertNotNull(constructor, t.toString, fieldName, fieldType.toString)
+ } else {
+ constructor
+ }
}
}
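
The `nullable` flag destructured from `schemaFor` above follows Scala's value/reference split. A rough sketch of that rule (Spark's real `schemaFor` handles many more types; this version is illustrative only):

import scala.reflect.runtime.universe._

def sketchNullable(tpe: Type): Boolean = tpe match {
  case t if t <:< typeOf[Option[_]] => true   // None is encoded as a SQL null
  case t if t <:< typeOf[AnyVal]    => false  // Int, Long, Double, ... cannot hold null
  case _                            => true   // any reference type can be null
}

sketchNullable(typeOf[Int])         // false -> field constructor wrapped in AssertNotNull
sketchNullable(typeOf[String])      // true  -> constructor used as-is
sketchNullable(typeOf[Option[Int]]) // true  -> constructor used as-is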
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 492cc9bf41..d40cd96905 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -624,3 +624,43 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
"""
}
}
+
+/**
+ * Asserts that input values of a non-nullable child expression are not null.
+ *
+ * Note that there are cases where `child.nullable == true`, while we still need to add this
+ * assertion. Consider a nullable column `s` whose data type is a struct containing a non-nullable
+ * `Int` field named `i`. Expression `s.i` is nullable because `s` can be null. However, for all
+ * non-null `s`, `s.i` can't be null.
+ */
+case class AssertNotNull(
+ child: Expression, parentType: String, fieldName: String, fieldType: String)
+ extends UnaryExpression {
+
+ override def dataType: DataType = child.dataType
+
+ override def nullable: Boolean = false
+
+ override def eval(input: InternalRow): Any =
+ throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
+
+ override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val childGen = child.gen(ctx)
+
+ ev.isNull = "false"
+ ev.value = childGen.value
+
+ s"""
+ ${childGen.code}
+
+ if (${childGen.isNull}) {
+ throw new RuntimeException(
+ "Null value appeared in non-nullable field $parentType.$fieldName of type $fieldType. " +
+ "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
+ "please try to use scala.Option[_] or other nullable types " +
+ "(e.g. java.lang.Integer instead of int/scala.Int)."
+ );
+ }
+ """
+ }
+}
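
Concretely, the edge case described in the class comment looks like this from the user side (a sketch against the 1.6-era API, assuming a `SQLContext` named `sqlContext` with its implicits imported; it mirrors the suites below):

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

case class Inner(i: Int)    // non-nullable field: a Scala Int cannot hold null
case class Outer(s: Inner)  // nullable field: the struct itself may be null

val schema = StructType(Seq(
  StructField("s", StructType(Seq(
    StructField("i", IntegerType, nullable = true)  // incoming data may carry nulls
  )), nullable = true)
))

val rows = sqlContext.sparkContext.parallelize(Seq(Row(Row(1)), Row(null)))
sqlContext.createDataFrame(rows, schema).as[Outer].collect()
// => Array(Outer(Inner(1)), Outer(null))
// A null struct is fine: the assertion is evaluated under the outer
// null check, so it never fires for Row(null).

// Row(Row(null)) instead fails fast at deserialization time with:
//   "Null value appeared in non-nullable field ...Inner.i of type Int. ..."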
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index 815a03f7c1..764ffdc094 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -36,12 +36,16 @@ class EncoderResolutionSuite extends PlanTest {
val encoder = ExpressionEncoder[StringLongClass]
val cls = classOf[StringLongClass]
+
{
val attrs = Seq('a.string, 'b.int)
val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
val expected: Expression = NewInstance(
cls,
- toExternalString('a.string) :: 'b.int.cast(LongType) :: Nil,
+ Seq(
+ toExternalString('a.string),
+ AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long")
+ ),
false,
ObjectType(cls))
compareExpressions(fromRowExpr, expected)
@@ -52,7 +56,10 @@ class EncoderResolutionSuite extends PlanTest {
val fromRowExpr = encoder.resolve(attrs, null).fromRowExpression
val expected = NewInstance(
cls,
- toExternalString('a.int.cast(StringType)) :: 'b.long :: Nil,
+ Seq(
+ toExternalString('a.int.cast(StringType)),
+ AssertNotNull('b.long, cls.getName, "b", "Long")
+ ),
false,
ObjectType(cls))
compareExpressions(fromRowExpr, expected)
@@ -69,7 +76,7 @@ class EncoderResolutionSuite extends PlanTest {
val expected: Expression = NewInstance(
cls,
Seq(
- 'a.int.cast(LongType),
+ AssertNotNull('a.int.cast(LongType), cls.getName, "a", "Long"),
If(
'b.struct('a.int, 'b.long).isNull,
Literal.create(null, ObjectType(innerCls)),
@@ -78,7 +85,9 @@ class EncoderResolutionSuite extends PlanTest {
Seq(
toExternalString(
GetStructField('b.struct('a.int, 'b.long), 0, Some("a")).cast(StringType)),
- GetStructField('b.struct('a.int, 'b.long), 1, Some("b"))),
+ AssertNotNull(
+ GetStructField('b.struct('a.int, 'b.long), 1, Some("b")),
+ innerCls.getName, "b", "Long")),
false,
ObjectType(innerCls))
)),
@@ -102,7 +111,9 @@ class EncoderResolutionSuite extends PlanTest {
cls,
Seq(
toExternalString(GetStructField('a.struct('a.string, 'b.byte), 0, Some("a"))),
- GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType)),
+ AssertNotNull(
+ GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType),
+ cls.getName, "b", "Long")),
false,
ObjectType(cls)),
'b.int.cast(LongType)),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index d201d65238..a763a95144 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
import scala.collection.JavaConverters._
+import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.function._
import org.apache.spark.rdd.RDD
@@ -64,7 +65,7 @@ import org.apache.spark.util.Utils
class Dataset[T] private[sql](
@transient override val sqlContext: SQLContext,
@transient override val queryExecution: QueryExecution,
- tEncoder: Encoder[T]) extends Queryable with Serializable {
+ tEncoder: Encoder[T]) extends Queryable with Serializable with Logging {
/**
* An unresolved version of the internal encoder for the type of this [[Dataset]]. This one is
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 0dbaeb81c7..9f8db39e33 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -23,6 +23,8 @@ import java.sql.Date;
import java.sql.Timestamp;
import java.util.*;
+import com.google.common.base.Objects;
+import org.junit.rules.ExpectedException;
import scala.Tuple2;
import scala.Tuple3;
import scala.Tuple4;
@@ -39,7 +41,6 @@ import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.test.TestSQLContext;
import org.apache.spark.sql.catalyst.encoders.OuterScopes;
import org.apache.spark.sql.catalyst.expressions.GenericRow;
-import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.StructType;
import static org.apache.spark.sql.functions.*;
@@ -741,4 +742,127 @@ public class JavaDatasetSuite implements Serializable {
context.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class));
ds.collect();
}
+
+ public class SmallBean implements Serializable {
+ private String a;
+
+ private int b;
+
+ public int getB() {
+ return b;
+ }
+
+ public void setB(int b) {
+ this.b = b;
+ }
+
+ public String getA() {
+ return a;
+ }
+
+ public void setA(String a) {
+ this.a = a;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ SmallBean smallBean = (SmallBean) o;
+ return b == smallBean.b && Objects.equal(a, smallBean.a);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(a, b);
+ }
+ }
+
+ public class NestedSmallBean implements Serializable {
+ private SmallBean f;
+
+ public SmallBean getF() {
+ return f;
+ }
+
+ public void setF(SmallBean f) {
+ this.f = f;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ NestedSmallBean that = (NestedSmallBean) o;
+ return Objects.equal(f, that.f);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(f);
+ }
+ }
+
+ @Rule
+ public transient ExpectedException nullabilityCheck = ExpectedException.none();
+
+ @Test
+ public void testRuntimeNullabilityCheck() {
+ OuterScopes.addOuterScope(this);
+
+ StructType schema = new StructType()
+ .add("f", new StructType()
+ .add("a", StringType, true)
+ .add("b", IntegerType, true), true);
+
+ // Shouldn't throw runtime exception since it passes nullability check.
+ {
+ Row row = new GenericRow(new Object[] {
+ new GenericRow(new Object[] {
+ "hello", 1
+ })
+ });
+
+ DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+ Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
+
+ SmallBean smallBean = new SmallBean();
+ smallBean.setA("hello");
+ smallBean.setB(1);
+
+ NestedSmallBean nestedSmallBean = new NestedSmallBean();
+ nestedSmallBean.setF(smallBean);
+
+ Assert.assertEquals(ds.collectAsList(), Collections.singletonList(nestedSmallBean));
+ }
+
+ // Shouldn't throw runtime exception when parent object (`SmallBean`) is null
+ {
+ Row row = new GenericRow(new Object[] { null });
+
+ DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+ Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
+
+ NestedSmallBean nestedSmallBean = new NestedSmallBean();
+ Assert.assertEquals(ds.collectAsList(), Collections.singletonList(nestedSmallBean));
+ }
+
+ nullabilityCheck.expect(RuntimeException.class);
+ nullabilityCheck.expectMessage(
+ "Null value appeared in non-nullable field " +
+ "test.org.apache.spark.sql.JavaDatasetSuite$SmallBean.b of type int.");
+
+ {
+ Row row = new GenericRow(new Object[] {
+ new GenericRow(new Object[] {
+ "hello", null
+ })
+ });
+
+ DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+ Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
+
+ ds.collect();
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index de012a9a56..3337996309 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -24,6 +24,7 @@ import scala.language.postfixOps
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
class DatasetSuite extends QueryTest with SharedSQLContext {
@@ -515,12 +516,44 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
}
assert(e.getMessage.contains("cannot resolve 'c' given input columns a, b"), e.getMessage)
}
+
+ test("runtime nullability check") {
+ val schema = StructType(Seq(
+ StructField("f", StructType(Seq(
+ StructField("a", StringType, nullable = true),
+ StructField("b", IntegerType, nullable = false)
+ )), nullable = true)
+ ))
+
+ def buildDataset(rows: Row*): Dataset[NestedStruct] = {
+ val rowRDD = sqlContext.sparkContext.parallelize(rows)
+ sqlContext.createDataFrame(rowRDD, schema).as[NestedStruct]
+ }
+
+ checkAnswer(
+ buildDataset(Row(Row("hello", 1))),
+ NestedStruct(ClassData("hello", 1))
+ )
+
+ // Shouldn't throw runtime exception when parent object (`ClassData`) is null
+ assert(buildDataset(Row(null)).collect() === Array(NestedStruct(null)))
+
+ val message = intercept[RuntimeException] {
+ buildDataset(Row(Row("hello", null))).collect()
+ }.getMessage
+
+ assert(message.contains(
+ "Null value appeared in non-nullable field org.apache.spark.sql.ClassData.b of type Int."
+ ))
+ }
}
case class ClassData(a: String, b: Int)
case class ClassData2(c: String, d: Int)
case class ClassNullableData(a: String, b: Integer)
+case class NestedStruct(f: ClassData)
+
/**
* A class used to test serialization using encoders. This class throws exceptions when using
* Java serialization -- so the only way it can be "serialized" is through our encoders.