Diffstat (limited to 'sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala | 73
1 file changed, 41 insertions(+), 32 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
index 54e4c1a2c9..06a371b88b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
@@ -19,25 +19,27 @@ package org.apache.spark.sql.execution.datasources.csv
 
 import java.nio.charset.{Charset, StandardCharsets}
 
+import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.FileStatus
 import org.apache.hadoop.io.{LongWritable, Text}
 import org.apache.hadoop.mapred.TextInputFormat
-import org.apache.hadoop.mapreduce.Job
+import org.apache.hadoop.mapreduce._
 
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
-import org.apache.spark.sql.execution.datasources.CompressionCodecs
+import org.apache.spark.sql.catalyst.expressions.{JoinedRow, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.execution.datasources.{CompressionCodecs, HadoopFileLinesReader, PartitionedFile}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.{StringType, StructField, StructType}
 import org.apache.spark.util.SerializableConfiguration
 import org.apache.spark.util.collection.BitSet
 
 /**
-  * Provides access to CSV data from pure SQL statements.
-  */
+ * Provides access to CSV data from pure SQL statements.
+ */
 class DefaultSource extends FileFormat with DataSourceRegister {
 
   override def shortName(): String = "csv"
@@ -91,39 +93,46 @@ class DefaultSource extends FileFormat with DataSourceRegister {
     new CSVOutputWriterFactory(csvOptions)
   }
 
-  /**
-   * This supports to eliminate unneeded columns before producing an RDD
-   * containing all of its tuples as Row objects. This reads all the tokens of each line
-   * and then drop unneeded tokens without casting and type-checking by mapping
-   * both the indices produced by `requiredColumns` and the ones of tokens.
-   */
-  override def buildInternalScan(
+  override def buildReader(
       sqlContext: SQLContext,
      dataSchema: StructType,
-      requiredColumns: Array[String],
-      filters: Array[Filter],
-      bucketSet: Option[BitSet],
-      inputFiles: Seq[FileStatus],
-      broadcastedConf: Broadcast[SerializableConfiguration],
-      options: Map[String, String]): RDD[InternalRow] = {
-    // TODO: Filter before calling buildInternalScan.
-    val csvFiles = inputFiles.filterNot(_.getPath.getName startsWith "_")
-
+      partitionSchema: StructType,
+      requiredSchema: StructType,
+      filters: Seq[Filter],
+      options: Map[String, String]): (PartitionedFile) => Iterator[InternalRow] = {
     val csvOptions = new CSVOptions(options)
-    val pathsString = csvFiles.map(_.getPath.toUri.toString)
-    val header = dataSchema.fields.map(_.name)
-    val tokenizedRdd = tokenRdd(sqlContext, csvOptions, header, pathsString)
-    val rows = CSVRelation.parseCsv(
-      tokenizedRdd, dataSchema, requiredColumns, csvFiles, sqlContext, csvOptions)
-
-    val requiredDataSchema = StructType(requiredColumns.map(c => dataSchema.find(_.name == c).get))
-    rows.mapPartitions { iterator =>
-      val unsafeProjection = UnsafeProjection.create(requiredDataSchema)
-      iterator.map(unsafeProjection)
+    val headers = requiredSchema.fields.map(_.name)
+
+    val conf = new Configuration(sqlContext.sparkContext.hadoopConfiguration)
+    val broadcastedConf = sqlContext.sparkContext.broadcast(new SerializableConfiguration(conf))
+
+    (file: PartitionedFile) => {
+      val lineIterator = {
+        val conf = broadcastedConf.value.value
+        new HadoopFileLinesReader(file, conf).map { line =>
+          new String(line.getBytes, 0, line.getLength, csvOptions.charset)
+        }
+      }
+
+      CSVRelation.dropHeaderLine(file, lineIterator, csvOptions)
+
+      val unsafeRowIterator = {
+        val tokenizedIterator = new BulkCsvReader(lineIterator, csvOptions, headers)
+        val parser = CSVRelation.csvParser(dataSchema, requiredSchema.fieldNames, csvOptions)
+        tokenizedIterator.flatMap(parser(_).toSeq)
+      }
+
+      // Appends partition values
+      val fullOutput = requiredSchema.toAttributes ++ partitionSchema.toAttributes
+      val joinedRow = new JoinedRow()
+      val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput)
+
+      unsafeRowIterator.map { dataRow =>
+        appendPartitionColumns(joinedRow(dataRow, file.partitionValues))
+      }
     }
   }
 
-
   private def baseRdd(
       sqlContext: SQLContext,
       options: CSVOptions,
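Note on the new read path: buildReader no longer produces an RDD[InternalRow] over all input files; it returns a function from a single PartitionedFile to an Iterator[InternalRow], and the partition values carried by the file are appended to every data row via a JoinedRow fed through a code-generated unsafe projection. The sketch below illustrates that appending pattern in isolation; it is not part of the commit, and the schemas and the withPartitionValues helper are hypothetical stand-ins for requiredSchema, partitionSchema, and the closure in the diff.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.JoinedRow
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object PartitionValueSketch {
  // Hypothetical schemas standing in for requiredSchema and partitionSchema.
  val requiredSchema = StructType(Seq(
    StructField("name", StringType),
    StructField("age", IntegerType)))
  val partitionSchema = StructType(Seq(StructField("date", StringType)))

  // Same pattern as the diff: generate one unsafe projection over the
  // concatenated output, then route each data row through a reused JoinedRow
  // whose right half holds the (constant) partition values for this file.
  def withPartitionValues(
      dataRows: Iterator[InternalRow],
      partitionValues: InternalRow): Iterator[InternalRow] = {
    val fullOutput = requiredSchema.toAttributes ++ partitionSchema.toAttributes
    val joinedRow = new JoinedRow()
    val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput)
    dataRows.map(dataRow => appendPartitionColumns(joinedRow(dataRow, partitionValues)))
  }
}

The projection is generated once per file rather than once per row, which amortizes the codegen cost, and the single JoinedRow instance is reused across rows to avoid per-row allocation.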