author     Michael Armbrust <michael@databricks.com>  2015-01-12 15:19:09 -0800
committer  Michael Armbrust <michael@databricks.com>  2015-01-12 15:19:35 -0800
commit     6d23af6a4d3b864036fef2c80022d438de88cfb5 (patch)
tree       be1723ee80c000eaeaa941370d8e7f94d0144283
parent     5970f0bbc7ff31df3f8d2c6dc0b46cd9f63ebe9a (diff)
[SPARK-5049][SQL] Fix ordering of partition columns in ParquetTableScan
Followup to #3870. Props to rahulaggarwalguavus for identifying the issue.

Author: Michael Armbrust <michael@databricks.com>

Closes #3990 from marmbrus/SPARK-5049 and squashes the following commits:

dd03e4e [Michael Armbrust] Fill in the partition values of parquet scans instead of using JoinedRow

(cherry picked from commit 5d9fa550820543ee1b0ce82997917745973a5d65)
Signed-off-by: Michael Armbrust <michael@databricks.com>
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala         |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala  | 43
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala           | 12
3 files changed, 41 insertions(+), 18 deletions(-)
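
The change below replaces the old JoinedRow-based approach, which always prepended partition values to each row, with one that writes partition values directly into their requested ordinals. A rough, self-contained Scala sketch of the difference (plain collections, no Spark types; the column names and values are made up to mirror the new test, not taken from the implementation):

    // Hypothetical stand-alone illustration; none of these names come from Spark.
    object PartitionOrderingSketch {
      // Requested projection: SELECT stringField, p FROM partitioned_parquet
      val requestedColumns = Seq("stringField", "p")
      val partitionValues  = Map("p" -> 1)                   // partition column -> value
      val parquetValues    = Map("stringField" -> "part-1")  // columns Parquet actually returns

      // Old behaviour (JoinedRow): partition values are always prepended, so the row
      // comes back as (p, stringField) even though (stringField, p) was requested.
      val joinedRowStyle: Seq[Any] =
        partitionValues.values.toSeq ++
          requestedColumns.filterNot(partitionValues.contains).map(parquetValues)

      // New behaviour: build the row in requested order and fill partition columns
      // into their own ordinals, so the projection order is preserved.
      val filledRowStyle: Seq[Any] =
        requestedColumns.map(c => partitionValues.getOrElse(c, parquetValues(c)))

      def main(args: Array[String]): Unit = {
        println(joinedRowStyle) // List(1, part-1)  -- wrong order
        println(filledRowStyle) // List(part-1, 1)  -- matches SELECT stringField, p
      }
    }
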
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
index b237a07c72..2835dc3408 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
@@ -28,7 +28,7 @@ import parquet.schema.MessageType
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException}
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
/**
@@ -67,6 +67,8 @@ private[sql] case class ParquetRelation(
conf,
sqlContext.isParquetBinaryAsString)
+ lazy val attributeMap = AttributeMap(output.map(o => o -> o))
+
override def newInstance() = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type]
// Equals must also take into account the output attributes so that we can distinguish between
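
The new attributeMap is keyed so that lookups resolve requested attributes to the relation's own output attributes by exprId rather than by (case-sensitive) name. A minimal stand-alone sketch of that idea, using a made-up Attr case class in place of Catalyst's Attribute and AttributeMap:

    // Made-up stand-ins: Attr and the Long exprId below are not Spark classes.
    case class Attr(name: String, exprId: Long)

    val relationOutput = Seq(Attr("key", 1L), Attr("stringField", 2L))

    // Keyed the same way as AttributeMap(output.map(o => o -> o)), i.e. by exprId.
    val attributeMap: Map[Long, Attr] = relationOutput.map(o => o.exprId -> o).toMap

    // A requested attribute whose name differs only in case still resolves to the
    // relation's own output attribute, because only the exprId is consulted.
    val requested = Attr("KEY", 1L)
    assert(attributeMap(requested.exprId) == relationOutput.head)
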
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index 232ef90b01..072a4bcc42 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -63,18 +63,17 @@ case class ParquetTableScan(
// The resolution of Parquet attributes is case sensitive, so we resolve the original attributes
// by exprId. note: output cannot be transient, see
// https://issues.apache.org/jira/browse/SPARK-1367
- val normalOutput =
- attributes
- .filterNot(a => relation.partitioningAttributes.map(_.exprId).contains(a.exprId))
- .flatMap(a => relation.output.find(o => o.exprId == a.exprId))
+ val output = attributes.map(relation.attributeMap)
- val partOutput =
- attributes.flatMap(a => relation.partitioningAttributes.find(o => o.exprId == a.exprId))
+ // A mapping of ordinals partitionRow -> finalOutput.
+ val requestedPartitionOrdinals = {
+ val partitionAttributeOrdinals = AttributeMap(relation.partitioningAttributes.zipWithIndex)
- def output = partOutput ++ normalOutput
-
- assert(normalOutput.size + partOutput.size == attributes.size,
- s"$normalOutput + $partOutput != $attributes, ${relation.output}")
+ attributes.zipWithIndex.flatMap {
+ case (attribute, finalOrdinal) =>
+ partitionAttributeOrdinals.get(attribute).map(_ -> finalOrdinal)
+ }
+ }.toArray
override def execute(): RDD[Row] = {
import parquet.filter2.compat.FilterCompat.FilterPredicateCompat
@@ -96,7 +95,7 @@ case class ParquetTableScan(
// Store both requested and original schema in `Configuration`
conf.set(
RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA,
- ParquetTypesConverter.convertToString(normalOutput))
+ ParquetTypesConverter.convertToString(output))
conf.set(
RowWriteSupport.SPARK_ROW_SCHEMA,
ParquetTypesConverter.convertToString(relation.output))
@@ -124,7 +123,7 @@ case class ParquetTableScan(
classOf[Row],
conf)
- if (partOutput.nonEmpty) {
+ if (requestedPartitionOrdinals.nonEmpty) {
baseRDD.mapPartitionsWithInputSplit { case (split, iter) =>
val partValue = "([^=]+)=([^=]+)".r
val partValues =
@@ -137,15 +136,25 @@ case class ParquetTableScan(
case _ => None
}.toMap
+ // Convert the partitioning attributes into the correct types
val partitionRowValues =
- partOutput.map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow))
+ relation.partitioningAttributes
+ .map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow))
new Iterator[Row] {
- private[this] val joinedRow = new JoinedRow5(Row(partitionRowValues:_*), null)
-
def hasNext = iter.hasNext
-
- def next() = joinedRow.withRight(iter.next()._2)
+ def next() = {
+ val row = iter.next()._2.asInstanceOf[SpecificMutableRow]
+
+ // Parquet will leave partitioning columns empty, so we fill them in here.
+ var i = 0
+ while (i < requestedPartitionOrdinals.size) {
+ row(requestedPartitionOrdinals(i)._2) =
+ partitionRowValues(requestedPartitionOrdinals(i)._1)
+ i += 1
+ }
+ row
+ }
}
}
} else {
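
requestedPartitionOrdinals pairs each partition value's index with the ordinal that column occupies in the requested projection, and the while loop writes the casted partition values into the slots Parquet leaves empty. A self-contained sketch of that fill-in step, with a plain Array[Any] standing in for SpecificMutableRow and hard-coded values mirroring the new test:

    // Hypothetical stand-alone sketch; SpecificMutableRow is replaced by Array[Any].
    object FillPartitionValuesSketch {
      // For SELECT stringField, p FROM partitioned_parquet WHERE p = 1:
      // partition value index 0 ("p") must land at requested ordinal 1.
      val requestedPartitionOrdinals: Array[(Int, Int)] = Array(0 -> 1)
      val partitionRowValues: Seq[Any] = Seq(1) // e.g. Cast(Literal("1"), IntegerType) evaluated

      def fillPartitionValues(row: Array[Any]): Array[Any] = {
        var i = 0
        while (i < requestedPartitionOrdinals.length) {
          // Parquet leaves the partition slot empty; write the casted value into it.
          row(requestedPartitionOrdinals(i)._2) =
            partitionRowValues(requestedPartitionOrdinals(i)._1)
          i += 1
        }
        row
      }

      def main(args: Array[String]): Unit = {
        // The partition column ("p") comes back null from Parquet itself.
        println(fillPartitionValues(Array[Any]("part-1", null)).toSeq) // List(part-1, 1)
      }
    }
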
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala
index 488ebba043..06fe144666 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala
@@ -174,6 +174,18 @@ abstract class ParquetTest extends QueryTest with BeforeAndAfterAll {
}
Seq("partitioned_parquet", "partitioned_parquet_with_key").foreach { table =>
+ test(s"ordering of the partitioning columns $table") {
+ checkAnswer(
+ sql(s"SELECT p, stringField FROM $table WHERE p = 1"),
+ Seq.fill(10)((1, "part-1"))
+ )
+
+ checkAnswer(
+ sql(s"SELECT stringField, p FROM $table WHERE p = 1"),
+ Seq.fill(10)(("part-1", 1))
+ )
+ }
+
test(s"project the partitioning column $table") {
checkAnswer(
sql(s"SELECT p, count(*) FROM $table group by p"),