From bbd8f5bee81d5788c356977c173dd1edc42c77a3 Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Fri, 14 Nov 2014 14:21:16 -0800
Subject: [SPARK-4245][SQL] Fix containsNull of the result ArrayType of
 CreateArray expression.

The `containsNull` of the result `ArrayType` of `CreateArray` should be `true`
only if the children list is empty or there exists a nullable child.
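For example, the intended semantics look like this (a minimal sketch assuming
the Spark 1.2-era catalyst API, where the type classes still live under
org.apache.spark.sql.catalyst.types; the snippet is illustrative and not part
of the patch):

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, CreateArray, Literal}
    import org.apache.spark.sql.catalyst.types.StringType

    // Every child is a non-nullable literal, so no element can be null.
    CreateArray(Seq(Literal("a"), Literal("b"))).dataType
    // => ArrayType(StringType, containsNull = false)

    // One child is a nullable attribute, so null elements must be allowed.
    CreateArray(Seq(Literal("a"), AttributeReference("b", StringType, nullable = true)())).dataType
    // => ArrayType(StringType, containsNull = true)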
Author: Takuya UESHIN

Closes #3110 from ueshin/issues/SPARK-4245 and squashes the following commits:

6f64746 [Takuya UESHIN] Move equalsIgnoreNullability method into DataType.
5a90e02 [Takuya UESHIN] Refine InsertIntoHiveTable and add some comments.
cbecba8 [Takuya UESHIN] Fix a test title.
884ec37 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-4245
3c5274b [Takuya UESHIN] Add tests to insert data of types ArrayType / MapType / StructType with nullability is false into Hive table.
41a94a9 [Takuya UESHIN] Replace InsertIntoTable with InsertIntoHiveTable if data types ignoring nullability are same.
43e6ef5 [Takuya UESHIN] Fix containsNull for empty array.
778e997 [Takuya UESHIN] Fix containsNull of the result ArrayType of CreateArray expression.
---
 .../sql/catalyst/expressions/complexTypes.scala    |  4 +-
 .../spark/sql/catalyst/types/dataTypes.scala       | 21 +++++++++
 .../spark/sql/hive/HiveMetastoreCatalog.scala      | 27 ++++++++++++
 .../org/apache/spark/sql/hive/HiveStrategies.scala |  6 ++-
 .../spark/sql/hive/InsertIntoHiveTableSuite.scala  | 50 ++++++++++++++++++++++
 5 files changed, 106 insertions(+), 2 deletions(-)
(limited to 'sql')

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
index 19421e5667..917b346086 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
@@ -115,7 +115,9 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
 
   override def dataType: DataType = {
     assert(resolved, s"Invalid dataType of mixed ArrayType ${childTypes.mkString(",")}")
-    ArrayType(childTypes.headOption.getOrElse(NullType))
+    ArrayType(
+      childTypes.headOption.getOrElse(NullType),
+      containsNull = children.exists(_.nullable))
   }
 
   override def nullable: Boolean = false
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index 5dd19dd12d..ff1dc03069 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -171,6 +171,27 @@ object DataType {
       case _ =>
     }
   }
+
+  /**
+   * Compares two types, ignoring nullability of ArrayType, MapType, StructType.
+   */
+  def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
+    (left, right) match {
+      case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) =>
+        equalsIgnoreNullability(leftElementType, rightElementType)
+      case (MapType(leftKeyType, leftValueType, _), MapType(rightKeyType, rightValueType, _)) =>
+        equalsIgnoreNullability(leftKeyType, rightKeyType) &&
+        equalsIgnoreNullability(leftValueType, rightValueType)
+      case (StructType(leftFields), StructType(rightFields)) =>
+        leftFields.size == rightFields.size &&
+          leftFields.zip(rightFields)
+            .forall{
+              case (left, right) =>
+                left.name == right.name && equalsIgnoreNullability(left.dataType, right.dataType)
+            }
+      case (left, right) => left == right
+    }
+  }
 }
 
 abstract class DataType {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index d446650422..9045fc8558 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -286,6 +286,12 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
 
       if (childOutputDataTypes == tableOutputDataTypes) {
         p
+      } else if (childOutputDataTypes.size == tableOutputDataTypes.size &&
+        childOutputDataTypes.zip(tableOutputDataTypes)
+          .forall { case (left, right) => DataType.equalsIgnoreNullability(left, right) }) {
+        // If both types ignoring nullability of ArrayType, MapType, StructType are the same,
+        // use InsertIntoHiveTable instead of InsertIntoTable.
+        InsertIntoHiveTable(p.table, p.partition, p.child, p.overwrite)
       } else {
         // Only do the casting when child output data types differ from table output data types.
         val castedChildOutput = child.output.zip(table.output).map {
@@ -316,6 +322,27 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
   override def unregisterAllTables() = {}
 }
 
+/**
+ * A logical plan representing insertion into Hive table.
+ * This plan ignores nullability of ArrayType, MapType, StructType unlike InsertIntoTable
+ * because Hive table doesn't have nullability for ARRAY, MAP, STRUCT types.
+ */
+private[hive] case class InsertIntoHiveTable(
+    table: LogicalPlan,
+    partition: Map[String, Option[String]],
+    child: LogicalPlan,
+    overwrite: Boolean)
+  extends LogicalPlan {
+
+  override def children = child :: Nil
+  override def output = child.output
+
+  override lazy val resolved = childrenResolved && child.output.zip(table.output).forall {
+    case (childAttr, tableAttr) =>
+      DataType.equalsIgnoreNullability(childAttr.dataType, tableAttr.dataType)
+  }
+}
+
 /**
  * :: DeveloperApi ::
  * Provides conversions between Spark SQL data types and Hive Metastore types.
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 989740c8d4..3a49dddd85 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -161,7 +161,11 @@ private[hive] trait HiveStrategies {
   object DataSinks extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.InsertIntoTable(table: MetastoreRelation, partition, child, overwrite) =>
-        InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil
+        execution.InsertIntoHiveTable(
+          table, partition, planLater(child), overwrite)(hiveContext) :: Nil
+      case hive.InsertIntoHiveTable(table: MetastoreRelation, partition, child, overwrite) =>
+        execution.InsertIntoHiveTable(
+          table, partition, planLater(child), overwrite)(hiveContext) :: Nil
       case logical.CreateTableAsSelect(
           Some(database), tableName, child, allowExisting, Some(extra: ASTNode)) =>
         CreateTableAsSelect(
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index 5dbfb92313..fb481edc85 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -121,4 +121,54 @@ class InsertIntoHiveTableSuite extends QueryTest {
     sql("DROP TABLE table_with_partition")
     sql("DROP TABLE tmp_table")
   }
+
+  test("Insert ArrayType.containsNull == false") {
+    val schema = StructType(Seq(
+      StructField("a", ArrayType(StringType, containsNull = false))))
+    val rowRDD = TestHive.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i"))))
+    val schemaRDD = applySchema(rowRDD, schema)
+    schemaRDD.registerTempTable("tableWithArrayValue")
+    sql("CREATE TABLE hiveTableWithArrayValue(a Array<String>)")
+    sql("INSERT OVERWRITE TABLE hiveTableWithArrayValue SELECT a FROM tableWithArrayValue")
+
+    checkAnswer(
+      sql("SELECT * FROM hiveTableWithArrayValue"),
+      rowRDD.collect().toSeq)
+
+    sql("DROP TABLE hiveTableWithArrayValue")
+  }
+
+  test("Insert MapType.valueContainsNull == false") {
+    val schema = StructType(Seq(
+      StructField("m", MapType(StringType, StringType, valueContainsNull = false))))
+    val rowRDD = TestHive.sparkContext.parallelize(
+      (1 to 100).map(i => Row(Map(s"key$i" -> s"value$i"))))
+    val schemaRDD = applySchema(rowRDD, schema)
+    schemaRDD.registerTempTable("tableWithMapValue")
+    sql("CREATE TABLE hiveTableWithMapValue(m Map<String, String>)")
+    sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue")
+
+    checkAnswer(
+      sql("SELECT * FROM hiveTableWithMapValue"),
+      rowRDD.collect().toSeq)
+
+    sql("DROP TABLE hiveTableWithMapValue")
+  }
+
+  test("Insert StructType.fields.exists(_.nullable == false)") {
+    val schema = StructType(Seq(
+      StructField("s", StructType(Seq(StructField("f", StringType, nullable = false))))))
+    val rowRDD = TestHive.sparkContext.parallelize(
+      (1 to 100).map(i => Row(Row(s"value$i"))))
+    val schemaRDD = applySchema(rowRDD, schema)
+    schemaRDD.registerTempTable("tableWithStructValue")
+    sql("CREATE TABLE hiveTableWithStructValue(s Struct<f: String>)")
+    sql("INSERT OVERWRITE TABLE hiveTableWithStructValue SELECT s FROM tableWithStructValue")
+
+    checkAnswer(
+      sql("SELECT * FROM hiveTableWithStructValue"),
+      rowRDD.collect().toSeq)
+
hiveTableWithStructValue") + } } -- cgit v1.2.3