author     Wenchen Fan <wenchen@databricks.com>    2017-01-08 00:42:09 +0800
committer  Wenchen Fan <wenchen@databricks.com>    2017-01-08 00:42:09 +0800
commit     b3d39620c563e5f6a32a4082aa3908e1009c17d2 (patch)
tree       2b927e760096c78895228cb9d79f6730c8e9d182
parent     cdda3372a39b508f7d159426749b682476c813b9 (diff)
download   spark-b3d39620c563e5f6a32a4082aa3908e1009c17d2.tar.gz
           spark-b3d39620c563e5f6a32a4082aa3908e1009c17d2.tar.bz2
           spark-b3d39620c563e5f6a32a4082aa3908e1009c17d2.zip
[SPARK-19085][SQL] cleanup OutputWriterFactory and OutputWriter
## What changes were proposed in this pull request?

`OutputWriterFactory`/`OutputWriter` are internal interfaces, so we can remove some unnecessary APIs:

1. `OutputWriterFactory.newWriter(path: String)`: no one calls it and no one implements it.
2. `OutputWriter.write(row: Row)`: during execution we only call `writeInternal`, which is odd since `OutputWriter` is already an internal interface. We should rename `writeInternal` to `write`, remove `def write(row: Row)` together with its converter code, and have all implementations implement `def write(row: InternalRow)` directly (a minimal sketch of the new contract follows the file list below).

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #16479 from cloud-fan/hive-writer.
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala  10
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala  6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala  22
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala  26
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala  4
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala  5
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala  7
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala  8
13 files changed, 37 insertions, 69 deletions
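For orientation, here is a minimal sketch of what an `OutputWriter` implementation looks like under the cleaned-up contract: it implements `def write(row: InternalRow)` directly, with no `Row` conversion and no `writeInternal` override. It mirrors `SimpleTextOutputWriter` from this patch; the class name `ExampleOutputWriter` is hypothetical and not part of the change.

```scala
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.TaskAttemptContext

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter}
import org.apache.spark.sql.types.StructType

// Hypothetical writer shown only to illustrate the post-cleanup contract.
class ExampleOutputWriter(
    path: String,
    dataSchema: StructType,
    context: TaskAttemptContext) extends OutputWriter {

  private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))

  // The writer now receives the unconverted InternalRow; field access goes through dataSchema.
  override def write(row: InternalRow): Unit = {
    val line = row.toSeq(dataSchema).map(v => if (v == null) "" else v.toString).mkString(",")
    writer.write(line)
    writer.write("\n")
  }

  override def close(): Unit = writer.close()
}
```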
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index b5aa7ce4e1..89bbc1556c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -45,9 +45,12 @@ private[libsvm] class LibSVMOutputWriter(
private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))
- override def write(row: Row): Unit = {
- val label = row.get(0)
- val vector = row.get(1).asInstanceOf[Vector]
+ // This `asInstanceOf` is safe because it's guaranteed by `LibSVMFileFormat.verifySchema`
+ private val udt = dataSchema(1).dataType.asInstanceOf[VectorUDT]
+
+ override def write(row: InternalRow): Unit = {
+ val label = row.getDouble(0)
+ val vector = udt.deserialize(row.getStruct(1, udt.sqlType.length))
writer.write(label.toString)
vector.foreachActive { case (i, v) =>
writer.write(s" ${i + 1}:$v")
@@ -115,6 +118,7 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
+ verifySchema(dataSchema)
new OutputWriterFactory {
override def newInstance(
path: String,
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
index 2517de59fe..c701f38238 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
@@ -17,12 +17,12 @@
package org.apache.spark.ml.source.libsvm
-import java.io.File
+import java.io.{File, IOException}
import java.nio.charset.StandardCharsets
import com.google.common.io.Files
-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Row, SaveMode}
@@ -100,7 +100,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
test("write libsvm data failed due to invalid schema") {
val df = spark.read.format("text").load(path)
- intercept[SparkException] {
+ intercept[IOException] {
df.write.format("libsvm").save(path + "_2")
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 7e23260e65..b7f3559b65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -466,7 +466,7 @@ case class DataSource(
// SPARK-17230: Resolve the partition columns so InsertIntoHadoopFsRelationCommand does
// not need to have the query as child, to avoid to analyze an optimized query,
// because InsertIntoHadoopFsRelationCommand will be optimized first.
- val columns = partitionColumns.map { name =>
+ val partitionAttributes = partitionColumns.map { name =>
val plan = data.logicalPlan
plan.resolve(name :: Nil, data.sparkSession.sessionState.analyzer.resolver).getOrElse {
throw new AnalysisException(
@@ -485,7 +485,7 @@ case class DataSource(
InsertIntoHadoopFsRelationCommand(
outputPath = outputPath,
staticPartitions = Map.empty,
- partitionColumns = columns,
+ partitionColumns = partitionAttributes,
bucketSpec = bucketSpec,
fileFormat = format,
options = options,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 1eb4541e2c..16c5193eda 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -64,18 +64,18 @@ object FileFormatWriter extends Logging {
val outputWriterFactory: OutputWriterFactory,
val allColumns: Seq[Attribute],
val partitionColumns: Seq[Attribute],
- val nonPartitionColumns: Seq[Attribute],
+ val dataColumns: Seq[Attribute],
val bucketSpec: Option[BucketSpec],
val path: String,
val customPartitionLocations: Map[TablePartitionSpec, String],
val maxRecordsPerFile: Long)
extends Serializable {
- assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ nonPartitionColumns),
+ assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns),
s"""
|All columns: ${allColumns.mkString(", ")}
|Partition columns: ${partitionColumns.mkString(", ")}
- |Non-partition columns: ${nonPartitionColumns.mkString(", ")}
+ |Data columns: ${dataColumns.mkString(", ")}
""".stripMargin)
}
@@ -120,7 +120,7 @@ object FileFormatWriter extends Logging {
outputWriterFactory = outputWriterFactory,
allColumns = queryExecution.logical.output,
partitionColumns = partitionColumns,
- nonPartitionColumns = dataColumns,
+ dataColumns = dataColumns,
bucketSpec = bucketSpec,
path = outputSpec.outputPath,
customPartitionLocations = outputSpec.customPartitionLocations,
@@ -246,9 +246,8 @@ object FileFormatWriter extends Logging {
currentWriter = description.outputWriterFactory.newInstance(
path = tmpFilePath,
- dataSchema = description.nonPartitionColumns.toStructType,
+ dataSchema = description.dataColumns.toStructType,
context = taskAttemptContext)
- currentWriter.initConverter(dataSchema = description.nonPartitionColumns.toStructType)
}
override def execute(iter: Iterator[InternalRow]): Set[String] = {
@@ -267,7 +266,7 @@ object FileFormatWriter extends Logging {
}
val internalRow = iter.next()
- currentWriter.writeInternal(internalRow)
+ currentWriter.write(internalRow)
recordsInFile += 1
}
releaseResources()
@@ -364,9 +363,8 @@ object FileFormatWriter extends Logging {
currentWriter = description.outputWriterFactory.newInstance(
path = path,
- dataSchema = description.nonPartitionColumns.toStructType,
+ dataSchema = description.dataColumns.toStructType,
context = taskAttemptContext)
- currentWriter.initConverter(description.nonPartitionColumns.toStructType)
}
override def execute(iter: Iterator[InternalRow]): Set[String] = {
@@ -383,7 +381,7 @@ object FileFormatWriter extends Logging {
// Returns the data columns to be written given an input row
val getOutputRow = UnsafeProjection.create(
- description.nonPartitionColumns, description.allColumns)
+ description.dataColumns, description.allColumns)
// Returns the partition path given a partition key.
val getPartitionStringFunc = UnsafeProjection.create(
@@ -392,7 +390,7 @@ object FileFormatWriter extends Logging {
// Sorts the data before write, so that we only need one writer at the same time.
val sorter = new UnsafeKVExternalSorter(
sortingKeySchema,
- StructType.fromAttributes(description.nonPartitionColumns),
+ StructType.fromAttributes(description.dataColumns),
SparkEnv.get.blockManager,
SparkEnv.get.serializerManager,
TaskContext.get().taskMemoryManager().pageSizeBytes,
@@ -448,7 +446,7 @@ object FileFormatWriter extends Logging {
newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
}
- currentWriter.writeInternal(sortedIterator.getValue)
+ currentWriter.write(sortedIterator.getValue)
recordsInFile += 1
}
releaseResources()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
index 84ea58b68a..423009e4ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
@@ -45,7 +45,7 @@ case class InsertIntoHadoopFsRelationCommand(
bucketSpec: Option[BucketSpec],
fileFormat: FileFormat,
options: Map[String, String],
- @transient query: LogicalPlan,
+ query: LogicalPlan,
mode: SaveMode,
catalogTable: Option[CatalogTable],
fileIndex: Option[FileIndex])
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala
index a73c8146c1..868e537142 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala
@@ -47,19 +47,6 @@ abstract class OutputWriterFactory extends Serializable {
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter
-
- /**
- * Returns a new instance of [[OutputWriter]] that will write data to the given path.
- * This method gets called by each task on executor to write InternalRows to
- * format-specific files. Compared to the other `newInstance()`, this is a newer API that
- * passes only the path that the writer must write to. The writer must write to the exact path
- * and not modify it (do not add subdirectories, extensions, etc.). All other
- * file-format-specific information needed to create the writer must be passed
- * through the [[OutputWriterFactory]] implementation.
- */
- def newWriter(path: String): OutputWriter = {
- throw new UnsupportedOperationException("newInstance with just path not supported")
- }
}
@@ -74,22 +61,11 @@ abstract class OutputWriter {
* Persists a single row. Invoked on the executor side. When writing to dynamically partitioned
* tables, dynamic partition columns are not included in rows to be written.
*/
- def write(row: Row): Unit
+ def write(row: InternalRow): Unit
/**
* Closes the [[OutputWriter]]. Invoked on the executor side after all rows are persisted, before
* the task output is committed.
*/
def close(): Unit
-
- private var converter: InternalRow => Row = _
-
- protected[sql] def initConverter(dataSchema: StructType) = {
- converter =
- CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row]
- }
-
- protected[sql] def writeInternal(row: InternalRow): Unit = {
- write(converter(row))
- }
}
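With the converter gone from the base class, a writer that still prefers external `Row` values has to do the conversion itself. A hedged sketch, reusing the same `CatalystTypeConverters` helper the deleted code relied on; `RowOutputWriter` and `writeRow` are hypothetical names, not part of this patch.

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.execution.datasources.OutputWriter
import org.apache.spark.sql.types.StructType

// Hypothetical adapter: subclasses keep a Row-based API while satisfying the new contract.
abstract class RowOutputWriter(dataSchema: StructType) extends OutputWriter {
  private val toScalaRow: InternalRow => Row =
    CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row]

  // Row-based hook for subclasses.
  def writeRow(row: Row): Unit

  // New InternalRow-based contract: convert, then delegate.
  override def write(row: InternalRow): Unit = writeRow(toScalaRow(row))
}
```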
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index 23c07eb630..8c19be48c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -221,9 +221,7 @@ private[csv] class CsvOutputWriter(
row.get(ordinal, dt).toString
}
- override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
-
- override protected[sql] def writeInternal(row: InternalRow): Unit = {
+ override def write(row: InternalRow): Unit = {
csvWriter.writeRow(rowToString(row), printHeader)
printHeader = false
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index a9d8ddfe9d..be1f94dbad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -159,9 +159,7 @@ private[json] class JsonOutputWriter(
// create the Generator without separator inserted between 2 records
private[this] val gen = new JacksonGenerator(dataSchema, writer, options)
- override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
-
- override protected[sql] def writeInternal(row: InternalRow): Unit = {
+ override def write(row: InternalRow): Unit = {
gen.write(row)
gen.writeLineEnding()
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala
index 5c0f8af17a..8361762b09 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala
@@ -37,9 +37,7 @@ private[parquet] class ParquetOutputWriter(path: String, context: TaskAttemptCon
}.getRecordWriter(context)
}
- override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
-
- override def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row)
+ override def write(row: InternalRow): Unit = recordWriter.write(null, row)
override def close(): Unit = recordWriter.close(context)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
index 897e535953..6f6e301686 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
@@ -132,9 +132,7 @@ class TextOutputWriter(
private val writer = CodecStreams.createOutputStream(context, new Path(path))
- override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
-
- override protected[sql] def writeInternal(row: InternalRow): Unit = {
+ override def write(row: InternalRow): Unit = {
if (!row.isNullAt(0)) {
val utf8string = row.getUTF8String(0)
utf8string.writeTo(writer)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
index 0a7631f782..f496c01ce9 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
@@ -239,10 +239,7 @@ private[orc] class OrcOutputWriter(
).asInstanceOf[RecordWriter[NullWritable, Writable]]
}
- override def write(row: Row): Unit =
- throw new UnsupportedOperationException("call writeInternal")
-
- override protected[sql] def writeInternal(row: InternalRow): Unit = {
+ override def write(row: InternalRow): Unit = {
recordWriter.write(NullWritable.get(), serializer.serialize(row))
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala
index abc7c8cc4d..7501334f94 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql.sources
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
import org.apache.spark.TaskContext
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory}
import org.apache.spark.sql.types.StructType
@@ -42,14 +43,14 @@ class CommitFailureTestSource extends SimpleTextSource {
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
- new SimpleTextOutputWriter(path, context) {
+ new SimpleTextOutputWriter(path, dataSchema, context) {
var failed = false
TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
failed = true
SimpleTextRelation.callbackCalled = true
}
- override def write(row: Row): Unit = {
+ override def write(row: InternalRow): Unit = {
if (SimpleTextRelation.failWriter) {
sys.error("Intentional task writer failure for testing purpose.")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index 5fdf615259..1607c97cd6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -50,7 +50,7 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
- new SimpleTextOutputWriter(path, context)
+ new SimpleTextOutputWriter(path, dataSchema, context)
}
override def getFileExtension(context: TaskAttemptContext): String = ""
@@ -117,13 +117,13 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
}
}
-class SimpleTextOutputWriter(path: String, context: TaskAttemptContext)
+class SimpleTextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemptContext)
extends OutputWriter {
private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))
- override def write(row: Row): Unit = {
- val serialized = row.toSeq.map { v =>
+ override def write(row: InternalRow): Unit = {
+ val serialized = row.toSeq(dataSchema).map { v =>
if (v == null) "" else v.toString
}.mkString(",")