Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala |  3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala                       | 97
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala                   | 83
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala                         |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala                   |  9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala            |  5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala                        | 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala                          |  7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala                             |  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala                      |  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala                      |  2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala                 |  5
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala                   | 33
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala            | 71
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala                     | 29
15 files changed, 327 insertions(+), 35 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 89544add74..20cc8e90a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -120,7 +120,8 @@ case class InsertIntoTable(
override def output = child.output
override lazy val resolved = childrenResolved && child.output.zip(table.output).forall {
- case (childAttr, tableAttr) => childAttr.dataType == tableAttr.dataType
+ case (childAttr, tableAttr) =>
+ DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType)
}
}
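
Editor's note (not part of the patch): before this change, `InsertIntoTable.resolved` required the child's and the table's data types to be exactly equal, so schemas differing only in a nested nullability flag (for example `containsNull` on an `ArrayType`) kept the plan unresolved. A hypothetical spark-shell check, assuming spark-sql is on the classpath, shows why plain equality is too strict:

```scala
// Illustrative spark-shell snippet, not from the patch: plain DataType equality
// distinguishes array types that differ only in containsNull, which is what used
// to block resolution of InsertIntoTable.
import org.apache.spark.sql.types._

val childType = ArrayType(IntegerType, containsNull = false) // data has no null elements
val tableType = ArrayType(IntegerType, containsNull = true)  // table column tolerates nulls

childType == tableType  // false; the new check accepts this direction of the write
```

The reverse direction (nullable data into a non-nullable column) is still rejected, as the tests added further down spell out.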
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
index 2abb1caee9..92d322845f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
@@ -181,7 +181,7 @@ object DataType {
/**
* Compares two types, ignoring nullability of ArrayType, MapType, StructType.
*/
- private[sql] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
+ private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
(left, right) match {
case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) =>
equalsIgnoreNullability(leftElementType, rightElementType)
@@ -198,6 +198,43 @@ object DataType {
case (left, right) => left == right
}
}
+
+ /**
+ * Compares two types, ignoring compatible nullability of ArrayType, MapType, StructType.
+ *
+ * Compatible nullability is defined as follows:
+ * - If `from` and `to` are ArrayTypes, `from` has a compatible nullability with `to`
+ * if and only if `to.containsNull` is true, or both of `from.containsNull` and
+ * `to.containsNull` are false.
+ * - If `from` and `to` are MapTypes, `from` has a compatible nullability with `to`
+ * if and only if `to.valueContainsNull` is true, or both of `from.valueContainsNull` and
+ * `to.valueContainsNull` are false.
+ * - If `from` and `to` are StructTypes, `from` has a compatible nullability with `to`
+ * if and only if, for every pair of corresponding fields, `toField.nullable` is true, or both
+ * of `fromField.nullable` and `toField.nullable` are false.
+ */
+ private[sql] def equalsIgnoreCompatibleNullability(from: DataType, to: DataType): Boolean = {
+ (from, to) match {
+ case (ArrayType(fromElement, fn), ArrayType(toElement, tn)) =>
+ (tn || !fn) && equalsIgnoreCompatibleNullability(fromElement, toElement)
+
+ case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
+ (tn || !fn) &&
+ equalsIgnoreCompatibleNullability(fromKey, toKey) &&
+ equalsIgnoreCompatibleNullability(fromValue, toValue)
+
+ case (StructType(fromFields), StructType(toFields)) =>
+ fromFields.size == toFields.size &&
+ fromFields.zip(toFields).forall {
+ case (fromField, toField) =>
+ fromField.name == toField.name &&
+ (toField.nullable || !fromField.nullable) &&
+ equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType)
+ }
+
+ case (fromDataType, toDataType) => fromDataType == toDataType
+ }
+ }
}
@@ -230,6 +267,17 @@ abstract class DataType {
def prettyJson: String = pretty(render(jsonValue))
def simpleString: String = typeName
+
+ /** Check if `this` and `other` are the same data type when ignoring nullability
+ * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
+ */
+ private[spark] def sameType(other: DataType): Boolean =
+ DataType.equalsIgnoreNullability(this, other)
+
+ /** Returns the same data type but with all nullability fields set to true
+ * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
+ */
+ private[spark] def asNullable: DataType
}
/**
@@ -245,6 +293,8 @@ class NullType private() extends DataType {
// this type. Otherwise, the companion object would be of type "NullType$" in byte code.
// Defined with a private constructor so the companion object is the only possible instantiation.
override def defaultSize: Int = 1
+
+ private[spark] override def asNullable: NullType = this
}
case object NullType extends NullType
@@ -310,6 +360,8 @@ class StringType private() extends NativeType with PrimitiveType {
* The default size of a value of the StringType is 4096 bytes.
*/
override def defaultSize: Int = 4096
+
+ private[spark] override def asNullable: StringType = this
}
case object StringType extends StringType
@@ -344,6 +396,8 @@ class BinaryType private() extends NativeType with PrimitiveType {
* The default size of a value of the BinaryType is 4096 bytes.
*/
override def defaultSize: Int = 4096
+
+ private[spark] override def asNullable: BinaryType = this
}
case object BinaryType extends BinaryType
@@ -369,6 +423,8 @@ class BooleanType private() extends NativeType with PrimitiveType {
* The default size of a value of the BooleanType is 1 byte.
*/
override def defaultSize: Int = 1
+
+ private[spark] override def asNullable: BooleanType = this
}
case object BooleanType extends BooleanType
@@ -399,6 +455,8 @@ class TimestampType private() extends NativeType {
* The default size of a value of the TimestampType is 12 bytes.
*/
override def defaultSize: Int = 12
+
+ private[spark] override def asNullable: TimestampType = this
}
case object TimestampType extends TimestampType
@@ -427,6 +485,8 @@ class DateType private() extends NativeType {
* The default size of a value of the DateType is 4 bytes.
*/
override def defaultSize: Int = 4
+
+ private[spark] override def asNullable: DateType = this
}
case object DateType extends DateType
@@ -485,6 +545,8 @@ class LongType private() extends IntegralType {
override def defaultSize: Int = 8
override def simpleString = "bigint"
+
+ private[spark] override def asNullable: LongType = this
}
case object LongType extends LongType
@@ -514,6 +576,8 @@ class IntegerType private() extends IntegralType {
override def defaultSize: Int = 4
override def simpleString = "int"
+
+ private[spark] override def asNullable: IntegerType = this
}
case object IntegerType extends IntegerType
@@ -543,6 +607,8 @@ class ShortType private() extends IntegralType {
override def defaultSize: Int = 2
override def simpleString = "smallint"
+
+ private[spark] override def asNullable: ShortType = this
}
case object ShortType extends ShortType
@@ -572,6 +638,8 @@ class ByteType private() extends IntegralType {
override def defaultSize: Int = 1
override def simpleString = "tinyint"
+
+ private[spark] override def asNullable: ByteType = this
}
case object ByteType extends ByteType
@@ -638,6 +706,8 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT
case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)"
case None => "decimal(10,0)"
}
+
+ private[spark] override def asNullable: DecimalType = this
}
@@ -696,6 +766,8 @@ class DoubleType private() extends FractionalType {
* The default size of a value of the DoubleType is 8 bytes.
*/
override def defaultSize: Int = 8
+
+ private[spark] override def asNullable: DoubleType = this
}
case object DoubleType extends DoubleType
@@ -724,6 +796,8 @@ class FloatType private() extends FractionalType {
* The default size of a value of the FloatType is 4 bytes.
*/
override def defaultSize: Int = 4
+
+ private[spark] override def asNullable: FloatType = this
}
case object FloatType extends FloatType
@@ -772,6 +846,9 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
override def defaultSize: Int = 100 * elementType.defaultSize
override def simpleString = s"array<${elementType.simpleString}>"
+
+ private[spark] override def asNullable: ArrayType =
+ ArrayType(elementType.asNullable, containsNull = true)
}
@@ -1017,6 +1094,15 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
*/
private[sql] def merge(that: StructType): StructType =
StructType.merge(this, that).asInstanceOf[StructType]
+
+ private[spark] override def asNullable: StructType = {
+ val newFields = fields.map {
+ case StructField(name, dataType, nullable, metadata) =>
+ StructField(name, dataType.asNullable, nullable = true, metadata)
+ }
+
+ StructType(newFields)
+ }
}
@@ -1069,6 +1155,9 @@ case class MapType(
override def defaultSize: Int = 100 * (keyType.defaultSize + valueType.defaultSize)
override def simpleString = s"map<${keyType.simpleString},${valueType.simpleString}>"
+
+ private[spark] override def asNullable: MapType =
+ MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true)
}
@@ -1122,4 +1211,10 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
* The default size of a value of the UserDefinedType is 4096 bytes.
*/
override def defaultSize: Int = 4096
+
+ /**
+ * For a UDT, asNullable does not change the nullability of its internal sqlType; it simply
+ * returns itself.
+ */
+ private[spark] override def asNullable: UserDefinedType[UserType] = this
}
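
Editor's note (not part of the patch): `equalsIgnoreCompatibleNullability` and `asNullable` are private helpers, so they cannot be called from user code. The standalone sketch below mirrors their logic with hypothetical names (`compatibleNullability`, `relaxNullability`) to make the direction rule concrete: data may be stricter than the table it is written into, never looser.

```scala
// Minimal standalone sketch of the two helpers added above. The real methods are
// private[sql]/private[spark]; the names here are ours, not Spark's.
import org.apache.spark.sql.types._

object NullabilityHelpersSketch {
  // `from` may be written into `to` if `to` is at least as nullable at every position.
  def compatibleNullability(from: DataType, to: DataType): Boolean = (from, to) match {
    case (ArrayType(fe, fn), ArrayType(te, tn)) =>
      (tn || !fn) && compatibleNullability(fe, te)
    case (MapType(fk, fv, fn), MapType(tk, tv, tn)) =>
      (tn || !fn) && compatibleNullability(fk, tk) && compatibleNullability(fv, tv)
    case (StructType(ff), StructType(tf)) =>
      ff.length == tf.length && ff.zip(tf).forall { case (f, t) =>
        f.name == t.name && (t.nullable || !f.nullable) &&
          compatibleNullability(f.dataType, t.dataType)
      }
    case (f, t) => f == t
  }

  // Mirrors asNullable: force every nullability flag to true, recursively.
  def relaxNullability(dt: DataType): DataType = dt match {
    case ArrayType(e, _)    => ArrayType(relaxNullability(e), containsNull = true)
    case MapType(k, v, _)   => MapType(relaxNullability(k), relaxNullability(v), valueContainsNull = true)
    case StructType(fields) =>
      StructType(fields.map(f => f.copy(dataType = relaxNullability(f.dataType), nullable = true)))
    case other              => other
  }

  def main(args: Array[String]): Unit = {
    val strict  = ArrayType(IntegerType, containsNull = false)
    val relaxed = ArrayType(IntegerType, containsNull = true)
    println(compatibleNullability(strict, relaxed))  // true: safe direction
    println(compatibleNullability(relaxed, strict))  // false: would drop null safety
    println(relaxNullability(strict) == relaxed)     // true
  }
}
```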
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index c97e0bec3e..a1341ea13d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -115,4 +115,87 @@ class DataTypeSuite extends FunSuite {
checkDefaultSize(MapType(IntegerType, StringType, true), 410000)
checkDefaultSize(MapType(IntegerType, ArrayType(DoubleType), false), 80400)
checkDefaultSize(structType, 812)
+
+ def checkEqualsIgnoreCompatibleNullability(
+ from: DataType,
+ to: DataType,
+ expected: Boolean): Unit = {
+ val testName =
+ s"equalsIgnoreCompatibleNullability: (from: ${from}, to: ${to})"
+ test(testName) {
+ assert(DataType.equalsIgnoreCompatibleNullability(from, to) === expected)
+ }
+ }
+
+ checkEqualsIgnoreCompatibleNullability(
+ from = ArrayType(DoubleType, containsNull = true),
+ to = ArrayType(DoubleType, containsNull = true),
+ expected = true)
+ checkEqualsIgnoreCompatibleNullability(
+ from = ArrayType(DoubleType, containsNull = false),
+ to = ArrayType(DoubleType, containsNull = false),
+ expected = true)
+ checkEqualsIgnoreCompatibleNullability(
+ from = ArrayType(DoubleType, containsNull = false),
+ to = ArrayType(DoubleType, containsNull = true),
+ expected = true)
+ checkEqualsIgnoreCompatibleNullability(
+ from = ArrayType(DoubleType, containsNull = true),
+ to = ArrayType(DoubleType, containsNull = false),
+ expected = false)
+ checkEqualsIgnoreCompatibleNullability(
+ from = ArrayType(DoubleType, containsNull = false),
+ to = ArrayType(StringType, containsNull = false),
+ expected = false)
+
+ checkEqualsIgnoreCompatibleNullability(
+ from = MapType(StringType, DoubleType, valueContainsNull = true),
+ to = MapType(StringType, DoubleType, valueContainsNull = true),
+ expected = true)
+ checkEqualsIgnoreCompatibleNullability(
+ from = MapType(StringType, DoubleType, valueContainsNull = false),
+ to = MapType(StringType, DoubleType, valueContainsNull = false),
+ expected = true)
+ checkEqualsIgnoreCompatibleNullability(
+ from = MapType(StringType, DoubleType, valueContainsNull = false),
+ to = MapType(StringType, DoubleType, valueContainsNull = true),
+ expected = true)
+ checkEqualsIgnoreCompatibleNullability(
+ from = MapType(StringType, DoubleType, valueContainsNull = true),
+ to = MapType(StringType, DoubleType, valueContainsNull = false),
+ expected = false)
+ checkEqualsIgnoreCompatibleNullability(
+ from = MapType(StringType, ArrayType(IntegerType, true), valueContainsNull = true),
+ to = MapType(StringType, ArrayType(IntegerType, false), valueContainsNull = true),
+ expected = false)
+ checkEqualsIgnoreCompatibleNullability(
+ from = MapType(StringType, ArrayType(IntegerType, false), valueContainsNull = true),
+ to = MapType(StringType, ArrayType(IntegerType, true), valueContainsNull = true),
+ expected = true)
+
+
+ checkEqualsIgnoreCompatibleNullability(
+ from = StructType(StructField("a", StringType, nullable = true) :: Nil),
+ to = StructType(StructField("a", StringType, nullable = true) :: Nil),
+ expected = true)
+ checkEqualsIgnoreCompatibleNullability(
+ from = StructType(StructField("a", StringType, nullable = false) :: Nil),
+ to = StructType(StructField("a", StringType, nullable = false) :: Nil),
+ expected = true)
+ checkEqualsIgnoreCompatibleNullability(
+ from = StructType(StructField("a", StringType, nullable = false) :: Nil),
+ to = StructType(StructField("a", StringType, nullable = true) :: Nil),
+ expected = true)
+ checkEqualsIgnoreCompatibleNullability(
+ from = StructType(StructField("a", StringType, nullable = true) :: Nil),
+ to = StructType(StructField("a", StringType, nullable = false) :: Nil),
+ expected = false)
+ checkEqualsIgnoreCompatibleNullability(
+ from = StructType(
+ StructField("a", StringType, nullable = false) ::
+ StructField("b", StringType, nullable = true) :: Nil),
+ to = StructType(
+ StructField("a", StringType, nullable = false) ::
+ StructField("b", StringType, nullable = false) :: Nil),
+ expected = false)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
index 3b68b7c275..f9d0ba2241 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
@@ -23,7 +23,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext}
import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
private[sql] class DefaultSource
@@ -131,7 +131,7 @@ private[sql] case class JSONRelation(
override def equals(other: Any): Boolean = other match {
case that: JSONRelation =>
- (this.path == that.path) && (this.schema == that.schema)
+ (this.path == that.path) && this.schema.sameType(that.schema)
case _ => false
}
}
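
Editor's note (not part of the patch): `JSONRelation.equals` feeds relation caching, and a schema inferred from data can legitimately carry different nullability flags than one recorded in the metastore. A hypothetical spark-shell check shows why plain `StructType` equality would make such lookups miss, which is what the switch to the nullability-insensitive `sameType` avoids:

```scala
// Illustrative only: two schemas differing solely in a nullable flag are not `==`.
import org.apache.spark.sql.types._

val inferred = StructType(StructField("a", StringType, nullable = true)  :: Nil)
val declared = StructType(StructField("a", StringType, nullable = false) :: Nil)

inferred == declared  // false, even though the relations describe the same data
```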
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
index a0d1005c0c..fd161bae12 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
@@ -23,6 +23,7 @@ import java.util.logging.Level
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.fs.permission.FsAction
+import org.apache.spark.sql.types.{StructType, DataType}
import parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat}
import parquet.hadoop.metadata.CompressionCodecName
import parquet.schema.MessageType
@@ -172,9 +173,13 @@ private[sql] object ParquetRelation {
sqlContext.conf.parquetCompressionCodec.toUpperCase, CompressionCodecName.UNCOMPRESSED)
.name())
ParquetRelation.enableLogForwarding()
- ParquetTypesConverter.writeMetaData(attributes, path, conf)
+ // This is a hack. We always set nullable/containsNull/valueContainsNull to true
+ // for the schema of Parquet data.
+ val schema = StructType.fromAttributes(attributes).asNullable
+ val newAttributes = schema.toAttributes
+ ParquetTypesConverter.writeMetaData(newAttributes, path, conf)
new ParquetRelation(path.toString, Some(conf), sqlContext) {
- override val output = attributes
+ override val output = newAttributes
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index 225ec6db7d..62813a981e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -278,7 +278,10 @@ private[sql] case class InsertIntoParquetTable(
ParquetOutputFormat.setWriteSupportClass(job, writeSupport)
val conf = ContextUtil.getConfiguration(job)
- RowWriteSupport.setSchema(relation.output, conf)
+ // This is a hack. We always set nullable/containsNull/valueContainsNull to true
+ // for the schema of Parquet data.
+ val schema = StructType.fromAttributes(relation.output).asNullable
+ RowWriteSupport.setSchema(schema.toAttributes, conf)
val fspath = new Path(relation.path)
val fs = fspath.getFileSystem(conf)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index 6d56be3ab8..8d95858493 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -115,9 +115,15 @@ private[sql] class DefaultSource
}
val relation = if (doInsertion) {
+ // This is a hack. We always set nullable/containsNull/valueContainsNull to true
+ // for the schema of Parquet data.
+ val df =
+ sqlContext.createDataFrame(
+ data.queryExecution.toRdd,
+ data.schema.asNullable)
val createdRelation =
- createRelation(sqlContext, parameters, data.schema).asInstanceOf[ParquetRelation2]
- createdRelation.insert(data, overwrite = mode == SaveMode.Overwrite)
+ createRelation(sqlContext, parameters, df.schema).asInstanceOf[ParquetRelation2]
+ createdRelation.insert(df, overwrite = mode == SaveMode.Overwrite)
createdRelation
} else {
// If the save mode is Ignore, we will just create the relation based on existing data.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
index c9cd0e6e93..0e540dad81 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.sources
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.{LogicalRDD, RunnableCommand}
+import org.apache.spark.sql.execution.RunnableCommand
private[sql] case class InsertIntoDataSource(
logicalRelation: LogicalRelation,
@@ -29,7 +29,10 @@ private[sql] case class InsertIntoDataSource(
override def run(sqlContext: SQLContext) = {
val relation = logicalRelation.relation.asInstanceOf[InsertableRelation]
- relation.insert(DataFrame(sqlContext, query), overwrite)
+ val data = DataFrame(sqlContext, query)
+ // Apply the schema of the existing table to the new data.
+ val df = sqlContext.createDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
+ relation.insert(df, overwrite)
// Invalidate the cache.
sqlContext.cacheManager.invalidateCache(logicalRelation)
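
Editor's note (not part of the patch): `InsertIntoDataSource.run` now rebuilds the incoming DataFrame against the existing table's schema, so the rows are written with the nullability the relation already declares. A minimal sketch of that pattern, assuming the public `createDataFrame(RDD[Row], StructType)` API available on `SQLContext` at the time of this patch; all names and data below are illustrative:

```scala
// Sketch: re-impose an existing table schema on incoming rows before insertion,
// mirroring what InsertIntoDataSource.run does with logicalRelation.schema.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.types._

object ApplyExistingSchemaSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("schema-sketch").setMaster("local[1]"))
    val sqlContext = new SQLContext(sc)

    // Incoming data happens to contain no nulls, so its own schema would be non-nullable.
    val rows = sc.parallelize(Seq(Row(Seq(1, 2)), Row(Seq(3, 4))))

    // The existing table declares the relaxed (nullable) schema.
    val existingSchema = StructType(
      StructField("a", ArrayType(IntegerType, containsNull = true), nullable = true) :: Nil)

    // Rebuild the DataFrame against the table's schema before handing it to the relation.
    val df = sqlContext.createDataFrame(rows, existingSchema)
    df.printSchema()  // the array element is reported as nullable, matching the table

    sc.stop()
  }
}
```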
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala
index 8440581074..cfa58f1442 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala
@@ -56,7 +56,7 @@ private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] {
child: LogicalPlan) = {
val newChildOutput = expectedOutput.zip(child.output).map {
case (expected, actual) =>
- val needCast = !DataType.equalsIgnoreNullability(expected.dataType, actual.dataType)
+ val needCast = !expected.dataType.sameType(actual.dataType)
// We want to make sure the field names in the data to be inserted exactly match
// the names in the schema.
val needRename = expected.name != actual.name
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
index eb045e37bf..c11d0ae5bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
@@ -59,4 +59,6 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
}
override def userClass: Class[ExamplePoint] = classOf[ExamplePoint]
+
+ private[spark] override def asNullable: ExamplePointUDT = this
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 47fdb55432..23f424c0bf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -62,6 +62,8 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
}
override def userClass = classOf[MyDenseVector]
+
+ private[spark] override def asNullable: MyDenseVectorUDT = this
}
class UserDefinedTypeSuite extends QueryTest {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 74b4e767ca..86fc6548f9 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -638,7 +638,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
p
} else if (childOutputDataTypes.size == tableOutputDataTypes.size &&
childOutputDataTypes.zip(tableOutputDataTypes)
- .forall { case (left, right) => DataType.equalsIgnoreNullability(left, right) }) {
+ .forall { case (left, right) => left.sameType(right) }) {
// If the two types are the same when ignoring nullability of ArrayType, MapType, StructType,
// use InsertIntoHiveTable instead of InsertIntoTable.
InsertIntoHiveTable(p.table, p.partition, p.child, p.overwrite)
@@ -686,8 +686,7 @@ private[hive] case class InsertIntoHiveTable(
override def output = child.output
override lazy val resolved = childrenResolved && child.output.zip(table.output).forall {
- case (childAttr, tableAttr) =>
- DataType.equalsIgnoreNullability(childAttr.dataType, tableAttr.dataType)
+ case (childAttr, tableAttr) => childAttr.dataType.sameType(tableAttr.dataType)
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
index ffaef8eef1..36bd3f8fe2 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
@@ -169,6 +169,7 @@ case class CreateMetastoreDataSourceAsSelect(
options
}
+ var existingSchema = None: Option[StructType]
if (sqlContext.catalog.tableExists(Seq(tableName))) {
// Check if we need to throw an exception or just return.
mode match {
@@ -188,22 +189,7 @@ case class CreateMetastoreDataSourceAsSelect(
val createdRelation = LogicalRelation(resolved.relation)
EliminateSubQueries(sqlContext.table(tableName).logicalPlan) match {
case l @ LogicalRelation(i: InsertableRelation) =>
- if (l.schema != createdRelation.schema) {
- val errorDescription =
- s"Cannot append to table $tableName because the schema of this " +
- s"DataFrame does not match the schema of table $tableName."
- val errorMessage =
- s"""
- |$errorDescription
- |== Schemas ==
- |${sideBySide(
- s"== Expected Schema ==" +:
- l.schema.treeString.split("\\\n"),
- s"== Actual Schema ==" +:
- createdRelation.schema.treeString.split("\\\n")).mkString("\n")}
- """.stripMargin
- throw new AnalysisException(errorMessage)
- } else if (i != createdRelation.relation) {
+ if (i != createdRelation.relation) {
val errorDescription =
s"Cannot append to table $tableName because the resolved relation does not " +
s"match the existing relation of $tableName. " +
@@ -221,6 +207,7 @@ case class CreateMetastoreDataSourceAsSelect(
""".stripMargin
throw new AnalysisException(errorMessage)
}
+ existingSchema = Some(l.schema)
case o =>
throw new AnalysisException(s"Saving data in ${o.toString} is not supported.")
}
@@ -234,15 +221,23 @@ case class CreateMetastoreDataSourceAsSelect(
createMetastoreTable = true
}
- val df = DataFrame(hiveContext, query)
+ val data = DataFrame(hiveContext, query)
+ val df = existingSchema match {
+ // If we are inserting into an existing table, just use the existing schema.
+ case Some(schema) => sqlContext.createDataFrame(data.queryExecution.toRdd, schema)
+ case None => data
+ }
// Create the relation based on the data of df.
- ResolvedDataSource(sqlContext, provider, mode, optionsWithPath, df)
+ val resolved = ResolvedDataSource(sqlContext, provider, mode, optionsWithPath, df)
if (createMetastoreTable) {
+ // We will use the schema of resolved.relation as the schema of the table (instead of
+ // the schema of df). This is important because the nullability may be changed by the relation
+ // provider (for example, see org.apache.spark.sql.parquet.DefaultSource).
hiveContext.catalog.createDataSourceTable(
tableName,
- Some(df.schema),
+ Some(resolved.relation.schema),
provider,
optionsWithPath,
isExternal)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index 868c35f35f..5d6a6f3b64 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -34,6 +34,8 @@ import org.apache.spark.sql.hive.test.TestHive.implicits._
import org.apache.spark.sql.parquet.ParquetRelation2
import org.apache.spark.sql.sources.LogicalRelation
+import scala.collection.mutable.ArrayBuffer
+
/**
* Tests for persisting tables created though the data sources API into the metastore.
*/
@@ -581,7 +583,8 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
case LogicalRelation(p: ParquetRelation2) => // OK
case _ =>
fail(
- s"test_parquet_ctas should be converted to ${classOf[ParquetRelation2].getCanonicalName}")
+ "test_parquet_ctas should be converted to " +
+ s"${classOf[ParquetRelation2].getCanonicalName}")
}
// Cleanup and reset confs.
@@ -592,6 +595,72 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
}
}
+ test("Pre insert nullability check (ArrayType)") {
+ val df1 =
+ createDataFrame(Tuple1(Seq(Int.box(1), null.asInstanceOf[Integer])) :: Nil).toDF("a")
+ val expectedSchema1 =
+ StructType(
+ StructField("a", ArrayType(IntegerType, containsNull = true), nullable = true) :: Nil)
+ assert(df1.schema === expectedSchema1)
+ df1.saveAsTable("arrayInParquet", "parquet", SaveMode.Overwrite)
+
+ val df2 =
+ createDataFrame(Tuple1(Seq(2, 3)) :: Nil).toDF("a")
+ val expectedSchema2 =
+ StructType(
+ StructField("a", ArrayType(IntegerType, containsNull = false), nullable = true) :: Nil)
+ assert(df2.schema === expectedSchema2)
+ df2.insertInto("arrayInParquet", overwrite = false)
+ createDataFrame(Tuple1(Seq(4, 5)) :: Nil).toDF("a")
+ .saveAsTable("arrayInParquet", SaveMode.Append) // This one internally calls df2.insertInto.
+ createDataFrame(Tuple1(Seq(Int.box(6), null.asInstanceOf[Integer])) :: Nil).toDF("a")
+ .saveAsTable("arrayInParquet", "parquet", SaveMode.Append)
+ refreshTable("arrayInParquet")
+
+ checkAnswer(
+ sql("SELECT a FROM arrayInParquet"),
+ Row(ArrayBuffer(1, null)) ::
+ Row(ArrayBuffer(2, 3)) ::
+ Row(ArrayBuffer(4, 5)) ::
+ Row(ArrayBuffer(6, null)) :: Nil)
+
+ sql("DROP TABLE arrayInParquet")
+ }
+
+ test("Pre insert nullability check (MapType)") {
+ val df1 =
+ createDataFrame(Tuple1(Map(1 -> null.asInstanceOf[Integer])) :: Nil).toDF("a")
+ val mapType1 = MapType(IntegerType, IntegerType, valueContainsNull = true)
+ val expectedSchema1 =
+ StructType(
+ StructField("a", mapType1, nullable = true) :: Nil)
+ assert(df1.schema === expectedSchema1)
+ df1.saveAsTable("mapInParquet", "parquet", SaveMode.Overwrite)
+
+ val df2 =
+ createDataFrame(Tuple1(Map(2 -> 3)) :: Nil).toDF("a")
+ val mapType2 = MapType(IntegerType, IntegerType, valueContainsNull = false)
+ val expectedSchema2 =
+ StructType(
+ StructField("a", mapType2, nullable = true) :: Nil)
+ assert(df2.schema === expectedSchema2)
+ df2.insertInto("mapInParquet", overwrite = false)
+ createDataFrame(Tuple1(Map(4 -> 5)) :: Nil).toDF("a")
+ .saveAsTable("mapInParquet", SaveMode.Append) // This one internally calls df2.insertInto.
+ createDataFrame(Tuple1(Map(6 -> null.asInstanceOf[Integer])) :: Nil).toDF("a")
+ .saveAsTable("mapInParquet", "parquet", SaveMode.Append)
+ refreshTable("mapInParquet")
+
+ checkAnswer(
+ sql("SELECT a FROM mapInParquet"),
+ Row(Map(1 -> null)) ::
+ Row(Map(2 -> 3)) ::
+ Row(Map(4 -> 5)) ::
+ Row(Map(6 -> null)) :: Nil)
+
+ sql("DROP TABLE mapInParquet")
+ }
+
test("SPARK-6024 wide schema support") {
// We will need 80 splits for this schema if the threshold is 4000.
val schema = StructType((1 to 5000).map(i => StructField(s"c_${i}", StringType, true)))
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala
index c8da8eea4e..89b943f008 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.hive.test.TestHive.implicits._
import org.apache.spark.sql.sources.{InsertIntoDataSource, LogicalRelation}
import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.types._
// The data where the partitioning key exists only in the directory structure.
case class ParquetData(intField: Int, stringField: String)
@@ -522,6 +523,34 @@ class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase {
super.afterAll()
setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString)
}
+
+ test("values in arrays and maps stored in parquet are always nullable") {
+ val df = createDataFrame(Tuple2(Map(2 -> 3), Seq(4, 5, 6)) :: Nil).toDF("m", "a")
+ val mapType1 = MapType(IntegerType, IntegerType, valueContainsNull = false)
+ val arrayType1 = ArrayType(IntegerType, containsNull = false)
+ val expectedSchema1 =
+ StructType(
+ StructField("m", mapType1, nullable = true) ::
+ StructField("a", arrayType1, nullable = true) :: Nil)
+ assert(df.schema === expectedSchema1)
+
+ df.saveAsTable("alwaysNullable", "parquet")
+
+ val mapType2 = MapType(IntegerType, IntegerType, valueContainsNull = true)
+ val arrayType2 = ArrayType(IntegerType, containsNull = true)
+ val expectedSchema2 =
+ StructType(
+ StructField("m", mapType2, nullable = true) ::
+ StructField("a", arrayType2, nullable = true) :: Nil)
+
+ assert(table("alwaysNullable").schema === expectedSchema2)
+
+ checkAnswer(
+ sql("SELECT m, a FROM alwaysNullable"),
+ Row(Map(2 -> 3), Seq(4, 5, 6)))
+
+ sql("DROP TABLE alwaysNullable")
+ }
}
class ParquetDataSourceOffSourceSuite extends ParquetSourceSuiteBase {