Diffstat (limited to 'sql/core/src')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala               3
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java  126
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala          33
3 files changed, 160 insertions(+), 2 deletions(-)
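Taken together, the change adds a runtime nullability check to Dataset decoding: when a null value appears in a field whose target type cannot represent null (a primitive Int/int), decoding now fails fast with a descriptive RuntimeException. A minimal sketch of the resulting behavior, distilled from the tests below (assumes a SQLContext named `sqlContext`; illustration only, not part of the commit):

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

case class ClassData(a: String, b: Int)  // `b` is a primitive and cannot hold null

val schema = StructType(Seq(
  StructField("a", StringType, nullable = true),
  StructField("b", IntegerType, nullable = true)))

import sqlContext.implicits._  // brings the implicit Encoder[ClassData] into scope

val rows = sqlContext.sparkContext.parallelize(Seq(Row("hello", null)))
val ds = sqlContext.createDataFrame(rows, schema).as[ClassData]

// With this change, collect() throws a RuntimeException whose message names the
// offending field ("Null value appeared in non-nullable field ... of type Int.")
// rather than failing later with a less informative error.
ds.collect()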
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index d201d65238..a763a95144 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
import scala.collection.JavaConverters._
+import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.function._
import org.apache.spark.rdd.RDD
@@ -64,7 +65,7 @@ import org.apache.spark.util.Utils
class Dataset[T] private[sql](
    @transient override val sqlContext: SQLContext,
    @transient override val queryExecution: QueryExecution,
-    tEncoder: Encoder[T]) extends Queryable with Serializable {
+    tEncoder: Encoder[T]) extends Queryable with Serializable with Logging {
  /**
   * An unresolved version of the internal encoder for the type of this [[Dataset]]. This one is
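The production-side change is deliberately small: Dataset now also mixes in org.apache.spark.Logging, the Spark 1.x internal trait that supplies the logInfo/logWarning family of helpers (the call sites that motivated it are outside this hunk). The pattern, sketched with a hypothetical class:

import org.apache.spark.Logging

class Example extends Logging {  // `Example` is illustrative, not from this commit
  def run(): Unit = {
    // logWarning, logInfo, logDebug, etc. are inherited from the Logging trait
    logWarning("something worth flagging during decoding")
  }
}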
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 0dbaeb81c7..9f8db39e33 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -23,6 +23,8 @@ import java.sql.Date;
import java.sql.Timestamp;
import java.util.*;
+import com.google.common.base.Objects;
+import org.junit.rules.ExpectedException;
import scala.Tuple2;
import scala.Tuple3;
import scala.Tuple4;
@@ -39,7 +41,6 @@ import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.test.TestSQLContext;
import org.apache.spark.sql.catalyst.encoders.OuterScopes;
import org.apache.spark.sql.catalyst.expressions.GenericRow;
-import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.StructType;
import static org.apache.spark.sql.functions.*;
@@ -741,4 +742,127 @@ public class JavaDatasetSuite implements Serializable {
      context.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class));
    ds.collect();
  }
+
+  public class SmallBean implements Serializable {
+    private String a;
+
+    private int b;
+
+    public int getB() {
+      return b;
+    }
+
+    public void setB(int b) {
+      this.b = b;
+    }
+
+    public String getA() {
+      return a;
+    }
+
+    public void setA(String a) {
+      this.a = a;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+      SmallBean smallBean = (SmallBean) o;
+      return b == smallBean.b && Objects.equal(a, smallBean.a);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hashCode(a, b);
+    }
+  }
+
+  public class NestedSmallBean implements Serializable {
+    private SmallBean f;
+
+    public SmallBean getF() {
+      return f;
+    }
+
+    public void setF(SmallBean f) {
+      this.f = f;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+      NestedSmallBean that = (NestedSmallBean) o;
+      return Objects.equal(f, that.f);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hashCode(f);
+    }
+  }
+
+  @Rule
+  public transient ExpectedException nullabilityCheck = ExpectedException.none();
+
+  @Test
+  public void testRuntimeNullabilityCheck() {
+    OuterScopes.addOuterScope(this);
+
+    StructType schema = new StructType()
+      .add("f", new StructType()
+        .add("a", StringType, true)
+        .add("b", IntegerType, true), true);
+
+    // Shouldn't throw runtime exception since it passes the nullability check.
+    {
+      Row row = new GenericRow(new Object[] {
+        new GenericRow(new Object[] {
+          "hello", 1
+        })
+      });
+
+      DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+      Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
+
+      SmallBean smallBean = new SmallBean();
+      smallBean.setA("hello");
+      smallBean.setB(1);
+
+      NestedSmallBean nestedSmallBean = new NestedSmallBean();
+      nestedSmallBean.setF(smallBean);
+
+      Assert.assertEquals(ds.collectAsList(), Collections.singletonList(nestedSmallBean));
+    }
+
+    // Shouldn't throw runtime exception when parent object (`SmallBean`) is null
+    {
+      Row row = new GenericRow(new Object[] { null });
+
+      DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+      Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
+
+      NestedSmallBean nestedSmallBean = new NestedSmallBean();
+      Assert.assertEquals(ds.collectAsList(), Collections.singletonList(nestedSmallBean));
+    }
+
+    nullabilityCheck.expect(RuntimeException.class);
+    nullabilityCheck.expectMessage(
+      "Null value appeared in non-nullable field " +
+        "test.org.apache.spark.sql.JavaDatasetSuite$SmallBean.b of type int.");
+
+    {
+      Row row = new GenericRow(new Object[] {
+        new GenericRow(new Object[] {
+          "hello", null
+        })
+      });
+
+      DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+      Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
+
+      ds.collect();
+    }
+  }
}
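Note how the Java test reaches the runtime check: the schema declares column `b` as nullable even though the target bean stores it in a primitive int, so the conflict cannot be rejected when the Dataset is created and only surfaces while decoding the offending row (the ExpectedException rule asserts exactly that). The Scala suite below takes the complementary route, declaring `b` as non-nullable and then feeding a row that violates the declaration. The mismatch, restated as a sketch (`SmallData` is an illustrative stand-in for SmallBean):

import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

// The schema only promises that `b` *may* be null...
val schema = new StructType()
  .add("f", new StructType()
    .add("a", StringType, true)
    .add("b", IntegerType, true), true)

// ...while the decode target has no representation for a null `b`,
// so a null can only be caught per row, at decode time.
case class SmallData(a: String, b: Int)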
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index de012a9a56..3337996309 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -24,6 +24,7 @@ import scala.language.postfixOps
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
class DatasetSuite extends QueryTest with SharedSQLContext {
@@ -515,12 +516,44 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
    }
    assert(e.getMessage.contains("cannot resolve 'c' given input columns a, b"), e.getMessage)
  }
+
+  test("runtime nullability check") {
+    val schema = StructType(Seq(
+      StructField("f", StructType(Seq(
+        StructField("a", StringType, nullable = true),
+        StructField("b", IntegerType, nullable = false)
+      )), nullable = true)
+    ))
+
+    def buildDataset(rows: Row*): Dataset[NestedStruct] = {
+      val rowRDD = sqlContext.sparkContext.parallelize(rows)
+      sqlContext.createDataFrame(rowRDD, schema).as[NestedStruct]
+    }
+
+    checkAnswer(
+      buildDataset(Row(Row("hello", 1))),
+      NestedStruct(ClassData("hello", 1))
+    )
+
+    // Shouldn't throw runtime exception when parent object (`ClassData`) is null
+    assert(buildDataset(Row(null)).collect() === Array(NestedStruct(null)))
+
+    val message = intercept[RuntimeException] {
+      buildDataset(Row(Row("hello", null))).collect()
+    }.getMessage
+
+    assert(message.contains(
+      "Null value appeared in non-nullable field org.apache.spark.sql.ClassData.b of type Int."
+    ))
+  }
}
case class ClassData(a: String, b: Int)
case class ClassData2(c: String, d: Int)
case class ClassNullableData(a: String, b: Integer)
+case class NestedStruct(f: ClassData)
+
/**
 * A class used to test serialization using encoders. This class throws exceptions when using
 * Java serialization -- so the only way it can be "serialized" is through our encoders.
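For the complementary case, where nulls are legitimate, the existing ClassNullableData fixture shows the required shape: the field must be boxed so decoded values can represent null. The distinction, in a short sketch:

case class ClassData(a: String, b: Int)              // null in `b` now fails decoding, by design
case class ClassNullableData(a: String, b: Integer)  // boxed java.lang.Integer: null is representable

// e.g. Row("hello", null) can decode as ClassNullableData("hello", null), while
// decoding it as ClassData raises the runtime nullability error shown above.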