diff options
Diffstat (limited to 'sql')
11 files changed, 27 insertions, 63 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 7e23260e65..b7f3559b65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -466,7 +466,7 @@ case class DataSource( // SPARK-17230: Resolve the partition columns so InsertIntoHadoopFsRelationCommand does // not need to have the query as child, to avoid to analyze an optimized query, // because InsertIntoHadoopFsRelationCommand will be optimized first. - val columns = partitionColumns.map { name => + val partitionAttributes = partitionColumns.map { name => val plan = data.logicalPlan plan.resolve(name :: Nil, data.sparkSession.sessionState.analyzer.resolver).getOrElse { throw new AnalysisException( @@ -485,7 +485,7 @@ case class DataSource( InsertIntoHadoopFsRelationCommand( outputPath = outputPath, staticPartitions = Map.empty, - partitionColumns = columns, + partitionColumns = partitionAttributes, bucketSpec = bucketSpec, fileFormat = format, options = options, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 1eb4541e2c..16c5193eda 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -64,18 +64,18 @@ object FileFormatWriter extends Logging { val outputWriterFactory: OutputWriterFactory, val allColumns: Seq[Attribute], val partitionColumns: Seq[Attribute], - val nonPartitionColumns: Seq[Attribute], + val dataColumns: Seq[Attribute], val bucketSpec: Option[BucketSpec], val path: String, val customPartitionLocations: Map[TablePartitionSpec, String], val maxRecordsPerFile: Long) extends Serializable { - assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ nonPartitionColumns), + assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns), s""" |All columns: ${allColumns.mkString(", ")} |Partition columns: ${partitionColumns.mkString(", ")} - |Non-partition columns: ${nonPartitionColumns.mkString(", ")} + |Data columns: ${dataColumns.mkString(", ")} """.stripMargin) } @@ -120,7 +120,7 @@ object FileFormatWriter extends Logging { outputWriterFactory = outputWriterFactory, allColumns = queryExecution.logical.output, partitionColumns = partitionColumns, - nonPartitionColumns = dataColumns, + dataColumns = dataColumns, bucketSpec = bucketSpec, path = outputSpec.outputPath, customPartitionLocations = outputSpec.customPartitionLocations, @@ -246,9 +246,8 @@ object FileFormatWriter extends Logging { currentWriter = description.outputWriterFactory.newInstance( path = tmpFilePath, - dataSchema = description.nonPartitionColumns.toStructType, + dataSchema = description.dataColumns.toStructType, context = taskAttemptContext) - currentWriter.initConverter(dataSchema = description.nonPartitionColumns.toStructType) } override def execute(iter: Iterator[InternalRow]): Set[String] = { @@ -267,7 +266,7 @@ object FileFormatWriter extends Logging { } val internalRow = iter.next() - currentWriter.writeInternal(internalRow) + currentWriter.write(internalRow) recordsInFile += 1 } releaseResources() @@ -364,9 +363,8 @@ object FileFormatWriter extends Logging { currentWriter = description.outputWriterFactory.newInstance( path = path, - dataSchema = description.nonPartitionColumns.toStructType, + dataSchema = description.dataColumns.toStructType, context = taskAttemptContext) - currentWriter.initConverter(description.nonPartitionColumns.toStructType) } override def execute(iter: Iterator[InternalRow]): Set[String] = { @@ -383,7 +381,7 @@ object FileFormatWriter extends Logging { // Returns the data columns to be written given an input row val getOutputRow = UnsafeProjection.create( - description.nonPartitionColumns, description.allColumns) + description.dataColumns, description.allColumns) // Returns the partition path given a partition key. val getPartitionStringFunc = UnsafeProjection.create( @@ -392,7 +390,7 @@ object FileFormatWriter extends Logging { // Sorts the data before write, so that we only need one writer at the same time. val sorter = new UnsafeKVExternalSorter( sortingKeySchema, - StructType.fromAttributes(description.nonPartitionColumns), + StructType.fromAttributes(description.dataColumns), SparkEnv.get.blockManager, SparkEnv.get.serializerManager, TaskContext.get().taskMemoryManager().pageSizeBytes, @@ -448,7 +446,7 @@ object FileFormatWriter extends Logging { newOutputWriter(currentKey, getPartitionStringFunc, fileCounter) } - currentWriter.writeInternal(sortedIterator.getValue) + currentWriter.write(sortedIterator.getValue) recordsInFile += 1 } releaseResources() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 84ea58b68a..423009e4ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -45,7 +45,7 @@ case class InsertIntoHadoopFsRelationCommand( bucketSpec: Option[BucketSpec], fileFormat: FileFormat, options: Map[String, String], - @transient query: LogicalPlan, + query: LogicalPlan, mode: SaveMode, catalogTable: Option[CatalogTable], fileIndex: Option[FileIndex]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala index a73c8146c1..868e537142 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala @@ -47,19 +47,6 @@ abstract class OutputWriterFactory extends Serializable { path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter - - /** - * Returns a new instance of [[OutputWriter]] that will write data to the given path. - * This method gets called by each task on executor to write InternalRows to - * format-specific files. Compared to the other `newInstance()`, this is a newer API that - * passes only the path that the writer must write to. The writer must write to the exact path - * and not modify it (do not add subdirectories, extensions, etc.). All other - * file-format-specific information needed to create the writer must be passed - * through the [[OutputWriterFactory]] implementation. - */ - def newWriter(path: String): OutputWriter = { - throw new UnsupportedOperationException("newInstance with just path not supported") - } } @@ -74,22 +61,11 @@ abstract class OutputWriter { * Persists a single row. Invoked on the executor side. When writing to dynamically partitioned * tables, dynamic partition columns are not included in rows to be written. */ - def write(row: Row): Unit + def write(row: InternalRow): Unit /** * Closes the [[OutputWriter]]. Invoked on the executor side after all rows are persisted, before * the task output is committed. */ def close(): Unit - - private var converter: InternalRow => Row = _ - - protected[sql] def initConverter(dataSchema: StructType) = { - converter = - CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] - } - - protected[sql] def writeInternal(row: InternalRow): Unit = { - write(converter(row)) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index 23c07eb630..8c19be48c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -221,9 +221,7 @@ private[csv] class CsvOutputWriter( row.get(ordinal, dt).toString } - override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") - - override protected[sql] def writeInternal(row: InternalRow): Unit = { + override def write(row: InternalRow): Unit = { csvWriter.writeRow(rowToString(row), printHeader) printHeader = false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index a9d8ddfe9d..be1f94dbad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -159,9 +159,7 @@ private[json] class JsonOutputWriter( // create the Generator without separator inserted between 2 records private[this] val gen = new JacksonGenerator(dataSchema, writer, options) - override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") - - override protected[sql] def writeInternal(row: InternalRow): Unit = { + override def write(row: InternalRow): Unit = { gen.write(row) gen.writeLineEnding() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala index 5c0f8af17a..8361762b09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala @@ -37,9 +37,7 @@ private[parquet] class ParquetOutputWriter(path: String, context: TaskAttemptCon }.getRecordWriter(context) } - override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") - - override def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row) + override def write(row: InternalRow): Unit = recordWriter.write(null, row) override def close(): Unit = recordWriter.close(context) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index 897e535953..6f6e301686 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -132,9 +132,7 @@ class TextOutputWriter( private val writer = CodecStreams.createOutputStream(context, new Path(path)) - override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") - - override protected[sql] def writeInternal(row: InternalRow): Unit = { + override def write(row: InternalRow): Unit = { if (!row.isNullAt(0)) { val utf8string = row.getUTF8String(0) utf8string.writeTo(writer) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 0a7631f782..f496c01ce9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -239,10 +239,7 @@ private[orc] class OrcOutputWriter( ).asInstanceOf[RecordWriter[NullWritable, Writable]] } - override def write(row: Row): Unit = - throw new UnsupportedOperationException("call writeInternal") - - override protected[sql] def writeInternal(row: InternalRow): Unit = { + override def write(row: InternalRow): Unit = { recordWriter.write(NullWritable.get(), serializer.serialize(row)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala index abc7c8cc4d..7501334f94 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.sources import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.TaskContext -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory} import org.apache.spark.sql.types.StructType @@ -42,14 +43,14 @@ class CommitFailureTestSource extends SimpleTextSource { path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new SimpleTextOutputWriter(path, context) { + new SimpleTextOutputWriter(path, dataSchema, context) { var failed = false TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) => failed = true SimpleTextRelation.callbackCalled = true } - override def write(row: Row): Unit = { + override def write(row: InternalRow): Unit = { if (SimpleTextRelation.failWriter) { sys.error("Intentional task writer failure for testing purpose.") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 5fdf615259..1607c97cd6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -50,7 +50,7 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister { path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new SimpleTextOutputWriter(path, context) + new SimpleTextOutputWriter(path, dataSchema, context) } override def getFileExtension(context: TaskAttemptContext): String = "" @@ -117,13 +117,13 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister { } } -class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) +class SimpleTextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter { private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path)) - override def write(row: Row): Unit = { - val serialized = row.toSeq.map { v => + override def write(row: InternalRow): Unit = { + val serialized = row.toSeq(dataSchema).map { v => if (v == null) "" else v.toString }.mkString(",") |