aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala5
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala15
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala25
3 files changed, 35 insertions, 10 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index b3a197cd96..7afdf75f38 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.datasources.parquet.{DefaultSource => Parq
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.BaseRelation
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{DataType, StructType}
object RDDConversions {
def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = {
@@ -348,7 +348,8 @@ private[sql] object DataSourceScanExec {
}
relation match {
- case r: HadoopFsRelation if r.fileFormat.supportBatch(r.sqlContext, relation.schema) =>
+ case r: HadoopFsRelation
+ if r.fileFormat.supportBatch(r.sqlContext, StructType.fromAttributes(output)) =>
BatchedDataSourceScanExec(output, rdd, relation, outputPartitioning, metadata)
case _ =>
RowDataSourceScanExec(output, rdd, relation, outputPartitioning, metadata)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
index 38c0084952..bbbbc5ebe9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
@@ -286,10 +286,6 @@ private[sql] class DefaultSource
SQLConf.PARQUET_INT96_AS_TIMESTAMP.key,
sqlContext.conf.getConf(SQLConf.PARQUET_INT96_AS_TIMESTAMP))
- // Whole stage codegen (PhysicalRDD) is able to deal with batches directly
- val returningBatch =
- supportBatch(sqlContext, StructType(partitionSchema.fields ++ dataSchema.fields))
-
// Try to push down filters when filter push-down is enabled.
val pushed = if (sqlContext.getConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key).toBoolean) {
filters
@@ -308,8 +304,11 @@ private[sql] class DefaultSource
// TODO: if you move this into the closure it reverts to the default values.
// If true, enable using the custom RecordReader for parquet. This only works for
// a subset of the types (no complex types).
- val enableVectorizedParquetReader: Boolean = sqlContext.conf.parquetVectorizedReaderEnabled &&
- dataSchema.forall(_.dataType.isInstanceOf[AtomicType])
+ val resultSchema = StructType(partitionSchema.fields ++ requiredSchema.fields)
+ val enableVectorizedReader: Boolean = sqlContext.conf.parquetVectorizedReaderEnabled &&
+ resultSchema.forall(_.dataType.isInstanceOf[AtomicType])
+ // Whole stage codegen (PhysicalRDD) is able to deal with batches directly
+ val returningBatch = supportBatch(sqlContext, resultSchema)
(file: PartitionedFile) => {
assert(file.partitionValues.numFields == partitionSchema.size)
@@ -329,7 +328,7 @@ private[sql] class DefaultSource
val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
val hadoopAttemptContext = new TaskAttemptContextImpl(broadcastedConf.value.value, attemptId)
- val parquetReader = if (enableVectorizedParquetReader) {
+ val parquetReader = if (enableVectorizedReader) {
val vectorizedReader = new VectorizedParquetRecordReader()
vectorizedReader.initialize(split, hadoopAttemptContext)
logDebug(s"Appending $partitionSchema ${file.partitionValues}")
@@ -356,7 +355,7 @@ private[sql] class DefaultSource
// UnsafeRowParquetRecordReader appends the columns internally to avoid another copy.
if (parquetReader.isInstanceOf[VectorizedParquetRecordReader] &&
- enableVectorizedParquetReader) {
+ enableVectorizedReader) {
iter.asInstanceOf[Iterator[InternalRow]]
} else {
val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index 7d206e7bc4..ed20c45d5f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -24,6 +24,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
+import org.apache.spark.sql.execution.BatchedDataSourceScanExec
import org.apache.spark.sql.execution.datasources.parquet.TestingUDT.{NestedStruct, NestedStructUDT}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
@@ -589,6 +590,30 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
checkAnswer(sqlContext.read.parquet(path), df)
}
}
+
+ test("returning batch for wide table") {
+ withSQLConf("spark.sql.codegen.maxFields" -> "100") {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ val df = sqlContext.range(100).select(Seq.tabulate(110) {i => ('id + i).as(s"c$i")} : _*)
+ df.write.mode(SaveMode.Overwrite).parquet(path)
+
+ // donot return batch, because whole stage codegen is disabled for wide table (>200 columns)
+ val df2 = sqlContext.read.parquet(path)
+ assert(df2.queryExecution.sparkPlan.find(_.isInstanceOf[BatchedDataSourceScanExec]).isEmpty,
+ "Should not return batch")
+ checkAnswer(df2, df)
+
+ // return batch
+ val columns = Seq.tabulate(90) {i => s"c$i"}
+ val df3 = df2.selectExpr(columns : _*)
+ assert(
+ df3.queryExecution.sparkPlan.find(_.isInstanceOf[BatchedDataSourceScanExec]).isDefined,
+ "Should not return batch")
+ checkAnswer(df3, df.selectExpr(columns : _*))
+ }
+ }
+ }
}
object TestingUDT {