-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala                                |  5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala                              | 34
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala                           | 14
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala      | 37
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala     | 13
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala  | 43
6 files changed, 117 insertions(+), 29 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala
index 9e0f9943bc..66f123682e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala
@@ -273,4 +273,9 @@ class MetadataBuilder {
map.put(key, value)
this
}
+
+ def remove(key: String): this.type = {
+ map.remove(key)
+ this
+ }
}
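
For reference, a minimal sketch (not part of the patch) of how the new MetadataBuilder.remove might be exercised; the key value "_OPTIONAL_" mirrors the constant introduced below:

  import org.apache.spark.sql.types.{Metadata, MetadataBuilder}

  // Build metadata carrying a temporary flag, then strip the flag again.
  val withFlag: Metadata = new MetadataBuilder()
    .putBoolean("_OPTIONAL_", true)
    .build()

  val cleaned: Metadata = new MetadataBuilder()
    .withMetadata(withFlag)  // copy the existing entries into the builder
    .remove("_OPTIONAL_")    // drop the temporary key (the method added here)
    .build()

  assert(!cleaned.contains("_OPTIONAL_"))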
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 3bd733fa2d..da0c92864e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -334,6 +334,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
object StructType extends AbstractDataType {
+ private[sql] val metadataKeyForOptionalField = "_OPTIONAL_"
+
override private[sql] def defaultConcreteType: DataType = new StructType
override private[sql] def acceptsType(other: DataType): Boolean = {
@@ -359,6 +361,18 @@ object StructType extends AbstractDataType {
protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType =
StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))
+ def removeMetadata(key: String, dt: DataType): DataType =
+ dt match {
+ case StructType(fields) =>
+ val newFields = fields.map { f =>
+ val mb = new MetadataBuilder()
+ f.copy(dataType = removeMetadata(key, f.dataType),
+ metadata = mb.withMetadata(f.metadata).remove(key).build())
+ }
+ StructType(newFields)
+ case _ => dt
+ }
+
private[sql] def merge(left: DataType, right: DataType): DataType =
(left, right) match {
case (ArrayType(leftElementType, leftContainsNull),
@@ -376,24 +390,32 @@ object StructType extends AbstractDataType {
case (StructType(leftFields), StructType(rightFields)) =>
val newFields = ArrayBuffer.empty[StructField]
+ // This metadata records the fields that exist in only one of the two StructTypes.
+ val optionalMeta = new MetadataBuilder()
val rightMapped = fieldsMap(rightFields)
leftFields.foreach {
case leftField @ StructField(leftName, leftType, leftNullable, _) =>
rightMapped.get(leftName)
.map { case rightField @ StructField(_, rightType, rightNullable, _) =>
- leftField.copy(
- dataType = merge(leftType, rightType),
- nullable = leftNullable || rightNullable)
- }
- .orElse(Some(leftField))
+ leftField.copy(
+ dataType = merge(leftType, rightType),
+ nullable = leftNullable || rightNullable)
+ }
+ .orElse {
+ optionalMeta.putBoolean(metadataKeyForOptionalField, true)
+ Some(leftField.copy(metadata = optionalMeta.build()))
+ }
.foreach(newFields += _)
}
val leftMapped = fieldsMap(leftFields)
rightFields
.filterNot(f => leftMapped.get(f.name).nonEmpty)
- .foreach(newFields += _)
+ .foreach { f =>
+ optionalMeta.putBoolean(metadataKeyForOptionalField, true)
+ newFields += f.copy(metadata = optionalMeta.build())
+ }
StructType(newFields)
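
For reference, a minimal sketch (not part of the patch) of how the merge marking and the new removeMetadata helper interact. Both merge and metadataKeyForOptionalField are private[sql], so this assumes code living under the org.apache.spark.sql package; the field names are made up:

  import org.apache.spark.sql.types._

  val left  = StructType(Seq(StructField("a", IntegerType)))
  val right = StructType(Seq(StructField("b", StringType)))

  // Fields that exist on only one side of the merge are flagged as optional.
  val merged = left.merge(right)
  assert(merged("a").metadata.getBoolean(StructType.metadataKeyForOptionalField))
  assert(merged("b").metadata.getBoolean(StructType.metadataKeyForOptionalField))

  // Before a schema is written out, the temporary flag can be stripped again.
  val cleaned = StructType
    .removeMetadata(StructType.metadataKeyForOptionalField, merged)
    .asInstanceOf[StructType]
  assert(cleaned.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField)))

Using field metadata as the marker keeps the merged schema's types and nullability untouched: ParquetFilters can consult the flag to skip unsafe filters, and ParquetRelation strips it before anything is persisted.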
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index 706ecd29d1..c2bbca7c33 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -122,7 +122,9 @@ class DataTypeSuite extends SparkFunSuite {
val right = StructType(List())
val merged = left.merge(right)
- assert(merged === left)
+ assert(DataType.equalsIgnoreCompatibleNullability(merged, left))
+ assert(merged("a").metadata.getBoolean(StructType.metadataKeyForOptionalField))
+ assert(merged("b").metadata.getBoolean(StructType.metadataKeyForOptionalField))
}
test("merge where left is empty") {
@@ -135,8 +137,9 @@ class DataTypeSuite extends SparkFunSuite {
val merged = left.merge(right)
- assert(right === merged)
-
+ assert(DataType.equalsIgnoreCompatibleNullability(merged, right))
+ assert(merged("a").metadata.getBoolean(StructType.metadataKeyForOptionalField))
+ assert(merged("b").metadata.getBoolean(StructType.metadataKeyForOptionalField))
}
test("merge where both are non-empty") {
@@ -154,7 +157,10 @@ class DataTypeSuite extends SparkFunSuite {
val merged = left.merge(right)
- assert(merged === expected)
+ assert(DataType.equalsIgnoreCompatibleNullability(merged, expected))
+ assert(merged("a").metadata.getBoolean(StructType.metadataKeyForOptionalField))
+ assert(merged("b").metadata.getBoolean(StructType.metadataKeyForOptionalField))
+ assert(merged("c").metadata.getBoolean(StructType.metadataKeyForOptionalField))
}
test("merge where right contains type conflict") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
index e9b734b0ab..5a5cb5cf03 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
@@ -208,10 +208,25 @@ private[sql] object ParquetFilters {
}
/**
+ * SPARK-11955: Optional fields carry the metadata key StructType.metadataKeyForOptionalField.
+ * Such fields exist in only one side of the merged schemas, so we can't push down filters that
+ * reference them, otherwise the Parquet library will throw an exception. Here we filter those
+ * fields out.
+ */
+ private def getFieldMap(dataType: DataType): Array[(String, DataType)] = dataType match {
+ case StructType(fields) =>
+ fields.filter { f =>
+ !f.metadata.contains(StructType.metadataKeyForOptionalField) ||
+ !f.metadata.getBoolean(StructType.metadataKeyForOptionalField)
+ }.map(f => f.name -> f.dataType) ++ fields.flatMap { f => getFieldMap(f.dataType) }
+ case _ => Array.empty[(String, DataType)]
+ }
+
+ /**
* Converts data sources filters to Parquet filter predicates.
*/
def createFilter(schema: StructType, predicate: sources.Filter): Option[FilterPredicate] = {
- val dataTypeOf = schema.map(f => f.name -> f.dataType).toMap
+ val dataTypeOf = getFieldMap(schema).toMap
relaxParquetValidTypeMap
@@ -231,29 +246,29 @@ private[sql] object ParquetFilters {
// Probably I missed something and obviously this should be changed.
predicate match {
- case sources.IsNull(name) =>
+ case sources.IsNull(name) if dataTypeOf.contains(name) =>
makeEq.lift(dataTypeOf(name)).map(_(name, null))
- case sources.IsNotNull(name) =>
+ case sources.IsNotNull(name) if dataTypeOf.contains(name) =>
makeNotEq.lift(dataTypeOf(name)).map(_(name, null))
- case sources.EqualTo(name, value) =>
+ case sources.EqualTo(name, value) if dataTypeOf.contains(name) =>
makeEq.lift(dataTypeOf(name)).map(_(name, value))
- case sources.Not(sources.EqualTo(name, value)) =>
+ case sources.Not(sources.EqualTo(name, value)) if dataTypeOf.contains(name) =>
makeNotEq.lift(dataTypeOf(name)).map(_(name, value))
- case sources.EqualNullSafe(name, value) =>
+ case sources.EqualNullSafe(name, value) if dataTypeOf.contains(name) =>
makeEq.lift(dataTypeOf(name)).map(_(name, value))
- case sources.Not(sources.EqualNullSafe(name, value)) =>
+ case sources.Not(sources.EqualNullSafe(name, value)) if dataTypeOf.contains(name) =>
makeNotEq.lift(dataTypeOf(name)).map(_(name, value))
- case sources.LessThan(name, value) =>
+ case sources.LessThan(name, value) if dataTypeOf.contains(name) =>
makeLt.lift(dataTypeOf(name)).map(_(name, value))
- case sources.LessThanOrEqual(name, value) =>
+ case sources.LessThanOrEqual(name, value) if dataTypeOf.contains(name) =>
makeLtEq.lift(dataTypeOf(name)).map(_(name, value))
- case sources.GreaterThan(name, value) =>
+ case sources.GreaterThan(name, value) if dataTypeOf.contains(name) =>
makeGt.lift(dataTypeOf(name)).map(_(name, value))
- case sources.GreaterThanOrEqual(name, value) =>
+ case sources.GreaterThanOrEqual(name, value) if dataTypeOf.contains(name) =>
makeGtEq.lift(dataTypeOf(name)).map(_(name, value))
case sources.In(name, valueSet) =>
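
For reference, a rough sketch (not part of the patch) of the effect of these guards. ParquetFilters and metadataKeyForOptionalField are private[sql], so this assumes test-style code under org.apache.spark.sql; the schema and values are made up:

  import org.apache.spark.sql.sources
  import org.apache.spark.sql.types._
  import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters

  // "c" is flagged as optional: it exists in only one of the merged files.
  val optional = new MetadataBuilder()
    .putBoolean(StructType.metadataKeyForOptionalField, true)
    .build()
  val schema = StructType(Seq(
    StructField("a", IntegerType),
    StructField("c", IntegerType, nullable = true, metadata = optional)))

  // A filter on an ordinary column is still converted; a filter on the
  // optional column is simply not pushed down, instead of being handed to
  // Parquet (which would otherwise hit PARQUET-389).
  assert(ParquetFilters.createFilter(schema, sources.EqualTo("a", 1)).isDefined)
  assert(ParquetFilters.createFilter(schema, sources.EqualTo("c", 1)).isEmpty)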
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
index b460ec1d26..f87590095d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
@@ -258,7 +258,12 @@ private[sql] class ParquetRelation(
job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]])
ParquetOutputFormat.setWriteSupportClass(job, classOf[CatalystWriteSupport])
- CatalystWriteSupport.setSchema(dataSchema, conf)
+
+ // Avoid saving this temporary metadata into the Parquet file.
+ // It is only useful for detecting optional columns when pushing down filters.
+ val dataSchemaToWrite = StructType.removeMetadata(StructType.metadataKeyForOptionalField,
+ dataSchema).asInstanceOf[StructType]
+ CatalystWriteSupport.setSchema(dataSchemaToWrite, conf)
// Sets flags for `CatalystSchemaConverter` (which converts Catalyst schema to Parquet schema)
// and `CatalystWriteSupport` (writing actual rows to Parquet files).
@@ -304,10 +309,6 @@ private[sql] class ParquetRelation(
val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString
val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp
- // When merging schemas is enabled and the column of the given filter does not exist,
- // Parquet emits an exception which is an issue of Parquet (PARQUET-389).
- val safeParquetFilterPushDown = !shouldMergeSchemas && parquetFilterPushDown
-
// Parquet row group size. We will use this value as the value for
// mapreduce.input.fileinputformat.split.minsize and mapred.min.split.size if the value
// of these flags are smaller than the parquet row group size.
@@ -321,7 +322,7 @@ private[sql] class ParquetRelation(
dataSchema,
parquetBlockSize,
useMetadataCache,
- safeParquetFilterPushDown,
+ parquetFilterPushDown,
assumeBinaryIsString,
assumeInt96IsTimestamp) _
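
End to end, the user-visible behavior this change targets looks roughly like the sketch below (not part of the patch; paths and column names are made up, and spark.sql.parquet.filterPushdown is assumed enabled). Previously, enabling schema merging disabled filter pushdown wholesale; now only filters on the optional columns are skipped:

  // Two Parquet directories with partially overlapping schemas.
  sqlContext.range(3).selectExpr("id AS a", "id AS b").write.parquet("/tmp/t/part1")
  sqlContext.range(3).selectExpr("id AS b", "id AS c").write.parquet("/tmp/t/part2")

  // With schema merging, a filter on a column that exists in only one file
  // is not pushed to Parquet, so pushdown can stay enabled for the rest.
  val df = sqlContext.read.option("mergeSchema", "true")
    .parquet("/tmp/t/part1", "/tmp/t/part2")
    .filter("c = 1")
  df.show()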
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index 97c5313f0f..1796b3af0e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelation}
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@@ -379,9 +380,47 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
// If the "c = 1" filter gets pushed down, this query will throw an exception which
// Parquet emits. This is a Parquet issue (PARQUET-389).
+ val df = sqlContext.read.parquet(pathOne, pathTwo).filter("c = 1").selectExpr("c", "b", "a")
checkAnswer(
- sqlContext.read.parquet(pathOne, pathTwo).filter("c = 1").selectExpr("c", "b", "a"),
- (1 to 1).map(i => Row(i, i.toString, null)))
+ df,
+ Row(1, "1", null))
+
+ // The fields "a" and "c" only exist in one Parquet file.
+ assert(df.schema("a").metadata.getBoolean(StructType.metadataKeyForOptionalField))
+ assert(df.schema("c").metadata.getBoolean(StructType.metadataKeyForOptionalField))
+
+ val pathThree = s"${dir.getCanonicalPath}/table3"
+ df.write.parquet(pathThree)
+
+ // We remove the temporary metadata when writing the Parquet file.
+ val schema = sqlContext.read.parquet(pathThree).schema
+ assert(schema.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField)))
+
+ val pathFour = s"${dir.getCanonicalPath}/table4"
+ val dfStruct = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
+ dfStruct.select(struct("a").as("s")).write.parquet(pathFour)
+
+ val pathFive = s"${dir.getCanonicalPath}/table5"
+ val dfStruct2 = sparkContext.parallelize(Seq((1, 1))).toDF("c", "b")
+ dfStruct2.select(struct("c").as("s")).write.parquet(pathFive)
+
+ // If the "s.c = 1" filter gets pushed down, this query will throw an exception which
+ // Parquet emits.
+ val dfStruct3 = sqlContext.read.parquet(pathFour, pathFive).filter("s.c = 1")
+ .selectExpr("s")
+ checkAnswer(dfStruct3, Row(Row(null, 1)))
+
+ // The fields "s.a" and "s.c" only exist in one Parquet file.
+ val field = dfStruct3.schema("s").dataType.asInstanceOf[StructType]
+ assert(field("a").metadata.getBoolean(StructType.metadataKeyForOptionalField))
+ assert(field("c").metadata.getBoolean(StructType.metadataKeyForOptionalField))
+
+ val pathSix = s"${dir.getCanonicalPath}/table6"
+ dfStruct3.write.parquet(pathSix)
+
+ // We remove the temporary metadata when writing the Parquet file.
+ val forPathSix = sqlContext.read.parquet(pathSix).schema
+ assert(forPathSix.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField)))
}
}
}