aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorTakuya UESHIN <ueshin@happy-camper.st>2014-08-26 13:22:55 -0700
committerMichael Armbrust <michael@databricks.com>2014-08-26 13:22:55 -0700
commit98c2bb0bbde6fb2b6f64af3efffefcb0dae94c12 (patch)
treeded21f0b71756a5d03c9c77cad09f90fadd69d20 /sql
parent3cedc4f4d78e093fd362085e0a077bb9e4f28ca5 (diff)
downloadspark-98c2bb0bbde6fb2b6f64af3efffefcb0dae94c12.tar.gz
spark-98c2bb0bbde6fb2b6f64af3efffefcb0dae94c12.tar.bz2
spark-98c2bb0bbde6fb2b6f64af3efffefcb0dae94c12.zip
[SPARK-2969][SQL] Make ScalaReflection be able to handle ArrayType.containsNull and MapType.valueContainsNull.
Make `ScalaReflection` be able to handle like: - `Seq[Int]` as `ArrayType(IntegerType, containsNull = false)` - `Seq[java.lang.Integer]` as `ArrayType(IntegerType, containsNull = true)` - `Map[Int, Long]` as `MapType(IntegerType, LongType, valueContainsNull = false)` - `Map[Int, java.lang.Long]` as `MapType(IntegerType, LongType, valueContainsNull = true)` Author: Takuya UESHIN <ueshin@happy-camper.st> Closes #1889 from ueshin/issues/SPARK-2969 and squashes the following commits: 24f1c5c [Takuya UESHIN] Change the default value of ArrayType.containsNull to true in Python API. 79f5b65 [Takuya UESHIN] Change the default value of ArrayType.containsNull to true in Java API. 7cd1a7a [Takuya UESHIN] Fix json test failures. 2cfb862 [Takuya UESHIN] Change the default value of ArrayType.containsNull to true. 2f38e61 [Takuya UESHIN] Revert the default value of MapTypes.valueContainsNull. 9fa02f5 [Takuya UESHIN] Fix a test failure. 1a9a96b [Takuya UESHIN] Modify ScalaReflection to handle ArrayType.containsNull and MapType.valueContainsNull.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala9
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala4
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala22
-rw-r--r--sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java4
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala2
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala32
6 files changed, 46 insertions, 27 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 0d26b52a84..6b6b636cd9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -62,11 +62,14 @@ object ScalaReflection {
sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
case t if t <:< typeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
- Schema(ArrayType(schemaFor(elementType).dataType), nullable = true)
+ val Schema(dataType, nullable) = schemaFor(elementType)
+ Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< typeOf[Map[_,_]] =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
- Schema(MapType(schemaFor(keyType).dataType, schemaFor(valueType).dataType), nullable = true)
- case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
+ val Schema(valueDataType, valueNullable) = schemaFor(valueType)
+ Schema(MapType(schemaFor(keyType).dataType,
+ valueDataType, valueContainsNull = valueNullable), nullable = true)
+ case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true)
case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index b52ee6d337..70c6d06cf2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -270,8 +270,8 @@ case object FloatType extends FractionalType {
}
object ArrayType {
- /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is false. */
- def apply(elementType: DataType): ArrayType = ArrayType(elementType, false)
+ /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */
+ def apply(elementType: DataType): ArrayType = ArrayType(elementType, true)
}
/**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index e75373d5a7..428607d8c8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -57,7 +57,9 @@ case class OptionalData(
case class ComplexData(
arrayField: Seq[Int],
- mapField: Map[Int, String],
+ arrayFieldContainsNull: Seq[java.lang.Integer],
+ mapField: Map[Int, Long],
+ mapFieldValueContainsNull: Map[Int, java.lang.Long],
structField: PrimitiveData)
case class GenericData[A](
@@ -116,8 +118,22 @@ class ScalaReflectionSuite extends FunSuite {
val schema = schemaFor[ComplexData]
assert(schema === Schema(
StructType(Seq(
- StructField("arrayField", ArrayType(IntegerType), nullable = true),
- StructField("mapField", MapType(IntegerType, StringType), nullable = true),
+ StructField(
+ "arrayField",
+ ArrayType(IntegerType, containsNull = false),
+ nullable = true),
+ StructField(
+ "arrayFieldContainsNull",
+ ArrayType(IntegerType, containsNull = true),
+ nullable = true),
+ StructField(
+ "mapField",
+ MapType(IntegerType, LongType, valueContainsNull = false),
+ nullable = true),
+ StructField(
+ "mapFieldValueContainsNull",
+ MapType(IntegerType, LongType, valueContainsNull = true),
+ nullable = true),
StructField(
"structField",
StructType(Seq(
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java
index 3eccddef88..37b4c8ffcb 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java
@@ -86,14 +86,14 @@ public abstract class DataType {
/**
* Creates an ArrayType by specifying the data type of elements ({@code elementType}).
- * The field of {@code containsNull} is set to {@code false}.
+ * The field of {@code containsNull} is set to {@code true}.
*/
public static ArrayType createArrayType(DataType elementType) {
if (elementType == null) {
throw new IllegalArgumentException("elementType should not be null.");
}
- return new ArrayType(elementType, false);
+ return new ArrayType(elementType, true);
}
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
index cf7d79f42d..8fb59c5830 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
@@ -24,7 +24,7 @@ class DataTypeSuite extends FunSuite {
test("construct an ArrayType") {
val array = ArrayType(StringType)
- assert(ArrayType(StringType, false) === array)
+ assert(ArrayType(StringType, true) === array)
}
test("construct an MapType") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 58b1e23891..05513a1271 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -130,11 +130,11 @@ class JsonSuite extends QueryTest {
checkDataType(
ArrayType(IntegerType, true), ArrayType(IntegerType, true), ArrayType(IntegerType, true))
checkDataType(
- ArrayType(IntegerType, false), ArrayType(IntegerType), ArrayType(IntegerType, false))
+ ArrayType(IntegerType, false), ArrayType(IntegerType), ArrayType(IntegerType, true))
checkDataType(
ArrayType(IntegerType, false), ArrayType(IntegerType, false), ArrayType(IntegerType, false))
checkDataType(
- ArrayType(IntegerType, false), ArrayType(IntegerType, false), ArrayType(IntegerType))
+ ArrayType(IntegerType, false), ArrayType(IntegerType, true), ArrayType(IntegerType, true))
// StructType
checkDataType(StructType(Nil), StructType(Nil), StructType(Nil))
@@ -201,26 +201,26 @@ class JsonSuite extends QueryTest {
val jsonSchemaRDD = jsonRDD(complexFieldAndType)
val expectedSchema = StructType(
- StructField("arrayOfArray1", ArrayType(ArrayType(StringType)), true) ::
- StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType)), true) ::
- StructField("arrayOfBigInteger", ArrayType(DecimalType), true) ::
- StructField("arrayOfBoolean", ArrayType(BooleanType), true) ::
- StructField("arrayOfDouble", ArrayType(DoubleType), true) ::
- StructField("arrayOfInteger", ArrayType(IntegerType), true) ::
- StructField("arrayOfLong", ArrayType(LongType), true) ::
+ StructField("arrayOfArray1", ArrayType(ArrayType(StringType, false), false), true) ::
+ StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType, false), false), true) ::
+ StructField("arrayOfBigInteger", ArrayType(DecimalType, false), true) ::
+ StructField("arrayOfBoolean", ArrayType(BooleanType, false), true) ::
+ StructField("arrayOfDouble", ArrayType(DoubleType, false), true) ::
+ StructField("arrayOfInteger", ArrayType(IntegerType, false), true) ::
+ StructField("arrayOfLong", ArrayType(LongType, false), true) ::
StructField("arrayOfNull", ArrayType(StringType, true), true) ::
- StructField("arrayOfString", ArrayType(StringType), true) ::
+ StructField("arrayOfString", ArrayType(StringType, false), true) ::
StructField("arrayOfStruct", ArrayType(
StructType(
StructField("field1", BooleanType, true) ::
StructField("field2", StringType, true) ::
- StructField("field3", StringType, true) :: Nil)), true) ::
+ StructField("field3", StringType, true) :: Nil), false), true) ::
StructField("struct", StructType(
StructField("field1", BooleanType, true) ::
StructField("field2", DecimalType, true) :: Nil), true) ::
StructField("structWithArrayFields", StructType(
- StructField("field1", ArrayType(IntegerType), true) ::
- StructField("field2", ArrayType(StringType), true) :: Nil), true) :: Nil)
+ StructField("field1", ArrayType(IntegerType, false), true) ::
+ StructField("field2", ArrayType(StringType, false), true) :: Nil), true) :: Nil)
assert(expectedSchema === jsonSchemaRDD.schema)
@@ -441,7 +441,7 @@ class JsonSuite extends QueryTest {
val jsonSchemaRDD = jsonRDD(complexFieldValueTypeConflict)
val expectedSchema = StructType(
- StructField("array", ArrayType(IntegerType), true) ::
+ StructField("array", ArrayType(IntegerType, false), true) ::
StructField("num_struct", StringType, true) ::
StructField("str_array", StringType, true) ::
StructField("struct", StructType(
@@ -467,7 +467,7 @@ class JsonSuite extends QueryTest {
val expectedSchema = StructType(
StructField("array1", ArrayType(StringType, true), true) ::
StructField("array2", ArrayType(StructType(
- StructField("field", LongType, true) :: Nil)), true) :: Nil)
+ StructField("field", LongType, true) :: Nil), false), true) :: Nil)
assert(expectedSchema === jsonSchemaRDD.schema)
@@ -492,7 +492,7 @@ class JsonSuite extends QueryTest {
val expectedSchema = StructType(
StructField("a", BooleanType, true) ::
StructField("b", LongType, true) ::
- StructField("c", ArrayType(IntegerType), true) ::
+ StructField("c", ArrayType(IntegerType, false), true) ::
StructField("d", StructType(
StructField("field", BooleanType, true) :: Nil), true) ::
StructField("e", StringType, true) :: Nil)