author     Cheng Lian <lian@databricks.com>    2016-05-04 14:16:57 +0800
committer  Cheng Lian <lian@databricks.com>    2016-05-04 14:16:57 +0800
commit     bc3760d405cc8c3ffcd957b188afa8b7e3b1f824
tree       aa6fae43f4eb0e9e88a0f2574bb2fa619954f98a
parent     695f0e9195209c75bfc62fc70bfc6d7d9f1047b3
[SPARK-14237][SQL] De-duplicate partition value appending logic in various buildReader() implementations
## What changes were proposed in this pull request?

Currently, various `FileFormat` data sources share approximately the same code for appending partition values. This PR eliminates that duplication.

A new method, `buildReaderWithPartitionValues()`, is added to `FileFormat` with a default implementation that appends partition values to the `InternalRow`s produced by the reader function returned by `buildReader()`. Special data sources override `buildReaderWithPartitionValues()` and simply delegate to `buildReader()`: Parquet, because its vectorized reader already appends partition values inside `buildReader()`, and the Text data source, because it doesn't support partitioning.

This PR brings two benefits:

1. It de-duplicates the partition value appending logic.
2. The reader function returned by `buildReader()` is now only required to produce `InternalRow`s rather than `UnsafeRow`s when the data source doesn't override `buildReaderWithPartitionValues()`, because the safe-to-unsafe conversion is performed while appending partition values. This makes 3rd-party data sources (e.g. spark-avro) easier to implement, since they no longer need to access private APIs involving `UnsafeRow`.

## How was this patch tested?

Existing tests should cover this change.

Author: Cheng Lian <lian@databricks.com>

Closes #12866 from liancheng/spark-14237-simplify-partition-values-appending.
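To illustrate the second point, below is a minimal sketch of the kind of reader function a third-party `buildReader()` can now return. It is not part of this patch: the `ExampleReader` object and its single `StringType` column are hypothetical, and it assumes `file.filePath` is a local `file:` URI. The point is that it emits only safe `GenericInternalRow`s and relies on the default `buildReaderWithPartitionValues()` to append `file.partitionValues` and perform the safe-to-unsafe conversion.

```scala
// Hypothetical sketch only: a reader function of the shape buildReader() must return.
// It produces *safe* InternalRows; the default buildReaderWithPartitionValues() added by
// this patch appends file.partitionValues and converts the resulting rows to UnsafeRow.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.unsafe.types.UTF8String

object ExampleReader {
  // Reads each line of a (local) text file as a single StringType column.
  val readFunction: PartitionedFile => Iterator[InternalRow] = { file =>
    val path = new java.io.File(new java.net.URI(file.filePath))
    scala.io.Source.fromFile(path).getLines().map { line =>
      // A plain GenericInternalRow is enough; no UnsafeProjection is needed here.
      new GenericInternalRow(Array[Any](UTF8String.fromString(line)))
    }
  }
}
```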
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala                      | 17
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala      |  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala       | 17
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala    | 40
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala       | 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala | 14
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala      | 13
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala            |  3
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala                          | 11
9 files changed, 74 insertions, 53 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index ba2e1e2bc2..5f78fab4dd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -204,25 +204,10 @@ class DefaultSource extends FileFormat with DataSourceRegister {
val converter = RowEncoder(dataSchema)
- val unsafeRowIterator = points.map { pt =>
+ points.map { pt =>
val features = if (sparse) pt.features.toSparse else pt.features.toDense
converter.toRow(Row(pt.label, features))
}
-
- def toAttribute(f: StructField): AttributeReference =
- AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
-
- // Appends partition values
- val fullOutput = (dataSchema ++ partitionSchema).map(toAttribute)
- val requiredOutput = fullOutput.filter { a =>
- requiredSchema.fieldNames.contains(a.name) || partitionSchema.fieldNames.contains(a.name)
- }
- val joinedRow = new JoinedRow()
- val appendPartitionColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput)
-
- unsafeRowIterator.map { dataRow =>
- appendPartitionColumns(joinedRow(dataRow, file.partitionValues))
- }
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index 615906a52e..8a93c6ff9a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -106,7 +106,7 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
val pushedDownFilters = dataFilters.flatMap(DataSourceStrategy.translateFilter)
logInfo(s"Pushed Filters: ${pushedDownFilters.mkString(",")}")
- val readFile = files.fileFormat.buildReader(
+ val readFile = files.fileFormat.buildReaderWithPartitionValues(
sparkSession = files.sparkSession,
dataSchema = files.dataSchema,
partitionSchema = files.partitionSchema,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
index 75143e609a..948fac0d58 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
@@ -117,20 +117,9 @@ class DefaultSource extends FileFormat with DataSourceRegister {
CSVRelation.dropHeaderLine(file, lineIterator, csvOptions)
- val unsafeRowIterator = {
- val tokenizedIterator = new BulkCsvReader(lineIterator, csvOptions, headers)
- val parser = CSVRelation.csvParser(dataSchema, requiredSchema.fieldNames, csvOptions)
- tokenizedIterator.flatMap(parser(_).toSeq)
- }
-
- // Appends partition values
- val fullOutput = requiredSchema.toAttributes ++ partitionSchema.toAttributes
- val joinedRow = new JoinedRow()
- val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput)
-
- unsafeRowIterator.map { dataRow =>
- appendPartitionColumns(joinedRow(dataRow, file.partitionValues))
- }
+ val tokenizedIterator = new BulkCsvReader(lineIterator, csvOptions, headers)
+ val parser = CSVRelation.csvParser(dataSchema, requiredSchema.fieldNames, csvOptions)
+ tokenizedIterator.flatMap(parser(_).toSeq)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
index 0a3461151c..24e2bf6d13 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
@@ -31,6 +31,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.execution.FileRelation
import org.apache.spark.sql.sources.{BaseRelation, Filter}
import org.apache.spark.sql.types.{StringType, StructType}
@@ -239,6 +240,45 @@ trait FileFormat {
}
/**
+ * Exactly the same as [[buildReader]] except that the reader function returned by this method
+ * appends partition values to [[InternalRow]]s produced by the reader function [[buildReader]]
+ * returns.
+ */
+ private[sql] def buildReaderWithPartitionValues(
+ sparkSession: SparkSession,
+ dataSchema: StructType,
+ partitionSchema: StructType,
+ requiredSchema: StructType,
+ filters: Seq[Filter],
+ options: Map[String, String],
+ hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = {
+ val dataReader = buildReader(
+ sparkSession, dataSchema, partitionSchema, requiredSchema, filters, options, hadoopConf)
+
+ new (PartitionedFile => Iterator[InternalRow]) with Serializable {
+ private val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes
+
+ private val joinedRow = new JoinedRow()
+
+ // Using lazy val to avoid serialization
+ private lazy val appendPartitionColumns =
+ GenerateUnsafeProjection.generate(fullSchema, fullSchema)
+
+ override def apply(file: PartitionedFile): Iterator[InternalRow] = {
+ // Using local val to avoid per-row lazy val check (pre-mature optimization?...)
+ val converter = appendPartitionColumns
+
+ // Note that we have to apply the converter even though `file.partitionValues` is empty.
+ // This is because the converter is also responsible for converting safe `InternalRow`s into
+ // `UnsafeRow`s.
+ dataReader(file).map { dataRow =>
+ converter(joinedRow(dataRow, file.partitionValues))
+ }
+ }
+ }
+ }
+
+ /**
* Returns a [[OutputWriterFactory]] for generating output writers that can write data.
* This method is current used only by FileStreamSinkWriter to generate output writers that
* does not use output committers to write data. The OutputWriter generated by the returned
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
index 62446583a5..4c97abed53 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
@@ -106,22 +106,14 @@ class DefaultSource extends FileFormat with DataSourceRegister {
val columnNameOfCorruptRecord = parsedOptions.columnNameOfCorruptRecord
.getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord)
- val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes
- val joinedRow = new JoinedRow()
-
(file: PartitionedFile) => {
val lines = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value).map(_.toString)
- val rows = JacksonParser.parseJson(
+ JacksonParser.parseJson(
lines,
requiredSchema,
columnNameOfCorruptRecord,
parsedOptions)
-
- val appendPartitionColumns = GenerateUnsafeProjection.generate(fullSchema, fullSchema)
- rows.map { row =>
- appendPartitionColumns(joinedRow(row, file.partitionValues))
- }
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
index 79185df673..cf5c8e94f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
@@ -255,6 +255,20 @@ private[sql] class DefaultSource
schema.forall(_.dataType.isInstanceOf[AtomicType])
}
+ override private[sql] def buildReaderWithPartitionValues(
+ sparkSession: SparkSession,
+ dataSchema: StructType,
+ partitionSchema: StructType,
+ requiredSchema: StructType,
+ filters: Seq[Filter],
+ options: Map[String, String],
+ hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
+ // For Parquet data source, `buildReader` already handles partition values appending. Here we
+ // simply delegate to `buildReader`.
+ buildReader(
+ sparkSession, dataSchema, partitionSchema, requiredSchema, filters, options, hadoopConf)
+ }
+
override def buildReader(
sparkSession: SparkSession,
dataSchema: StructType,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
index 348edfcf7a..f22c0241d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
@@ -83,6 +83,19 @@ class DefaultSource extends FileFormat with DataSourceRegister {
}
}
+ override private[sql] def buildReaderWithPartitionValues(
+ sparkSession: SparkSession,
+ dataSchema: StructType,
+ partitionSchema: StructType,
+ requiredSchema: StructType,
+ filters: Seq[Filter],
+ options: Map[String, String],
+ hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
+ // Text data source doesn't support partitioning. Here we simply delegate to `buildReader`.
+ buildReader(
+ sparkSession, dataSchema, partitionSchema, requiredSchema, filters, options, hadoopConf)
+ }
+
override def buildReader(
sparkSession: SparkSession,
dataSchema: StructType,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 07f00a0868..28e59055fa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -22,9 +22,6 @@ import java.nio.charset.UnsupportedCharsetException
import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
-import scala.collection.JavaConverters._
-
-import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.io.compress.GzipCodec
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
index d6a847f3ba..89d258e844 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
@@ -157,20 +157,11 @@ private[sql] class DefaultSource
}
// Unwraps `OrcStruct`s to `UnsafeRow`s
- val unsafeRowIterator = OrcRelation.unwrapOrcStructs(
+ OrcRelation.unwrapOrcStructs(
conf,
requiredSchema,
Some(orcRecordReader.getObjectInspector.asInstanceOf[StructObjectInspector]),
new RecordReaderIterator[OrcStruct](orcRecordReader))
-
- // Appends partition values
- val fullOutput = requiredSchema.toAttributes ++ partitionSchema.toAttributes
- val joinedRow = new JoinedRow()
- val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput)
-
- unsafeRowIterator.map { dataRow =>
- appendPartitionColumns(joinedRow(dataRow, file.partitionValues))
- }
}
}
}