Diffstat (limited to 'sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala  73
1 file changed, 41 insertions(+), 32 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
index 54e4c1a2c9..06a371b88b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
@@ -19,25 +19,27 @@ package org.apache.spark.sql.execution.datasources.csv
import java.nio.charset.{Charset, StandardCharsets}
+import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FileStatus
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapred.TextInputFormat
-import org.apache.hadoop.mapreduce.Job
+import org.apache.hadoop.mapreduce._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
-import org.apache.spark.sql.execution.datasources.CompressionCodecs
+import org.apache.spark.sql.catalyst.expressions.{JoinedRow, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.execution.datasources.{CompressionCodecs, HadoopFileLinesReader, PartitionedFile}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.util.SerializableConfiguration
import org.apache.spark.util.collection.BitSet
/**
- * Provides access to CSV data from pure SQL statements.
- */
+ * Provides access to CSV data from pure SQL statements.
+ */
class DefaultSource extends FileFormat with DataSourceRegister {
override def shortName(): String = "csv"
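Note (not part of the patch): because DefaultSource mixes in DataSourceRegister with shortName() returning "csv", callers can refer to the format by that short name instead of the fully-qualified class name. A minimal usage sketch against the SQLContext-era API of this branch, with a hypothetical local input path:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    object CsvShortNameSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("csv-short-name").setMaster("local[*]"))
        val sqlContext = new SQLContext(sc)

        // "csv" resolves to this DefaultSource through DataSourceRegister's shortName().
        val df = sqlContext.read
          .format("csv")
          .option("header", "true")   // option understood by CSVOptions
          .load("/tmp/people.csv")    // hypothetical path
        df.printSchema()

        sc.stop()
      }
    }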
@@ -91,39 +93,46 @@ class DefaultSource extends FileFormat with DataSourceRegister {
new CSVOutputWriterFactory(csvOptions)
}
- /**
- * This supports to eliminate unneeded columns before producing an RDD
- * containing all of its tuples as Row objects. This reads all the tokens of each line
- * and then drop unneeded tokens without casting and type-checking by mapping
- * both the indices produced by `requiredColumns` and the ones of tokens.
- */
- override def buildInternalScan(
+ override def buildReader(
sqlContext: SQLContext,
dataSchema: StructType,
- requiredColumns: Array[String],
- filters: Array[Filter],
- bucketSet: Option[BitSet],
- inputFiles: Seq[FileStatus],
- broadcastedConf: Broadcast[SerializableConfiguration],
- options: Map[String, String]): RDD[InternalRow] = {
- // TODO: Filter before calling buildInternalScan.
- val csvFiles = inputFiles.filterNot(_.getPath.getName startsWith "_")
-
+ partitionSchema: StructType,
+ requiredSchema: StructType,
+ filters: Seq[Filter],
+ options: Map[String, String]): (PartitionedFile) => Iterator[InternalRow] = {
val csvOptions = new CSVOptions(options)
- val pathsString = csvFiles.map(_.getPath.toUri.toString)
- val header = dataSchema.fields.map(_.name)
- val tokenizedRdd = tokenRdd(sqlContext, csvOptions, header, pathsString)
- val rows = CSVRelation.parseCsv(
- tokenizedRdd, dataSchema, requiredColumns, csvFiles, sqlContext, csvOptions)
-
- val requiredDataSchema = StructType(requiredColumns.map(c => dataSchema.find(_.name == c).get))
- rows.mapPartitions { iterator =>
- val unsafeProjection = UnsafeProjection.create(requiredDataSchema)
- iterator.map(unsafeProjection)
+ val headers = requiredSchema.fields.map(_.name)
+
+ val conf = new Configuration(sqlContext.sparkContext.hadoopConfiguration)
+ val broadcastedConf = sqlContext.sparkContext.broadcast(new SerializableConfiguration(conf))
+
+ (file: PartitionedFile) => {
+ val lineIterator = {
+ val conf = broadcastedConf.value.value
+ new HadoopFileLinesReader(file, conf).map { line =>
+ new String(line.getBytes, 0, line.getLength, csvOptions.charset)
+ }
+ }
+
+ CSVRelation.dropHeaderLine(file, lineIterator, csvOptions)
+
+ val unsafeRowIterator = {
+ val tokenizedIterator = new BulkCsvReader(lineIterator, csvOptions, headers)
+ val parser = CSVRelation.csvParser(dataSchema, requiredSchema.fieldNames, csvOptions)
+ tokenizedIterator.flatMap(parser(_).toSeq)
+ }
+
+ // Appends partition values
+ val fullOutput = requiredSchema.toAttributes ++ partitionSchema.toAttributes
+ val joinedRow = new JoinedRow()
+ val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput)
+
+ unsafeRowIterator.map { dataRow =>
+ appendPartitionColumns(joinedRow(dataRow, file.partitionValues))
+ }
}
}
-
private def baseRdd(
sqlContext: SQLContext,
options: CSVOptions,
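Note (not part of the patch): the new buildReader closure appends partition values to each parsed CSV row with a JoinedRow plus a generated UnsafeProjection over requiredSchema.toAttributes ++ partitionSchema.toAttributes. A minimal, self-contained sketch of that pattern, using hypothetical schemas and rows in place of a real PartitionedFile:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.JoinedRow
    import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
    import org.apache.spark.unsafe.types.UTF8String

    object AppendPartitionValuesSketch {
      def main(args: Array[String]): Unit = {
        // Hypothetical schemas standing in for requiredSchema and partitionSchema.
        val requiredSchema = StructType(Seq(StructField("id", IntegerType), StructField("name", StringType)))
        val partitionSchema = StructType(Seq(StructField("date", StringType)))

        // Same pattern as the patch: one generated projection over data ++ partition attributes.
        val fullOutput = requiredSchema.toAttributes ++ partitionSchema.toAttributes
        val joinedRow = new JoinedRow()
        val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput)

        // Hypothetical rows standing in for a parsed CSV record and file.partitionValues.
        val dataRow = InternalRow(1, UTF8String.fromString("alice"))
        val partitionValues = InternalRow(UTF8String.fromString("2016-03-15"))

        // JoinedRow concatenates the two rows; the projection copies them into one UnsafeRow.
        val unsafeRow = appendPartitionColumns(joinedRow(dataRow, partitionValues))
        println(unsafeRow.numFields)  // 3
      }
    }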