author     Davies Liu <davies@databricks.com>   2016-05-23 10:48:25 -0700
committer  Cheng Lian <lian@databricks.com>     2016-05-23 10:48:25 -0700
commit     80091b8a6840b562cf76341926e5b828d4def7e2 (patch)
tree       ffd8d5e84a62d98b787cd8ace6bb09fc5f7f2a64
parent     dafcb05c2ef8e09f45edfb7eabf58116c23975a0 (diff)
[SPARK-14031][SQL] speedup CSV writer
## What changes were proposed in this pull request?

Currently we create a CSVWriter for every row, which is expensive and memory hungry: writing out 1 million rows (two columns) takes about 15 seconds. This PR writes the rows in batch mode, creating a CSVWriter for every 1k rows, which brings writing 1 million rows down to about 1 second (15X faster).

## How was this patch tested?

Manually benchmarked.

Author: Davies Liu <davies@databricks.com>

Closes #13229 from davies/csv_writer.
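The change boils down to reusing one buffered writer and draining it in batches, instead of constructing a writer and output stream per row. A minimal sketch of that reuse-and-drain idea (hypothetical class name, naive comma-joining instead of real CSV quoting; not the actual Spark code):

```scala
// Sketch only: accumulate formatted rows in one reusable buffer and hand the
// buffered text back every `batchSize` rows, instead of allocating a fresh
// writer/stream for each row.
class BatchedLineWriter(batchSize: Int = 1024) {
  private val buffer = new StringBuilder
  private var pending = 0

  // Append one row; return the buffered text once a full batch has accumulated.
  def writeRow(fields: Seq[String]): Option[String] = {
    buffer.append(fields.mkString(",")).append('\n') // no quoting/escaping here
    pending += 1
    if (pending >= batchSize) Some(flush()) else None
  }

  // Drain whatever has been buffered so far and reset the state.
  def flush(): String = {
    val out = buffer.result()
    buffer.clear()
    pending = 0
    out
  }
}
```

The caller emits each returned batch to the underlying output and calls flush() once more on close, which is essentially what the two diffs below do with univocity's CsvWriter and Hadoop's RecordWriter.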
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala   | 19
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala | 23
2 files changed, 29 insertions(+), 13 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
index ae797a1e07..111995da9c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
@@ -76,17 +76,26 @@ private[sql] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) exten
   writerSettings.setQuoteAllFields(false)
   writerSettings.setHeaders(headers: _*)
 
-  def writeRow(row: Seq[String], includeHeader: Boolean): String = {
-    val buffer = new ByteArrayOutputStream()
-    val outputWriter = new OutputStreamWriter(buffer, StandardCharsets.UTF_8)
-    val writer = new CsvWriter(outputWriter, writerSettings)
+  private var buffer = new ByteArrayOutputStream()
+  private var writer = new CsvWriter(
+    new OutputStreamWriter(buffer, StandardCharsets.UTF_8),
+    writerSettings)
 
+  def writeRow(row: Seq[String], includeHeader: Boolean): Unit = {
     if (includeHeader) {
       writer.writeHeaders()
     }
     writer.writeRow(row.toArray: _*)
+  }
+
+  def flush(): String = {
     writer.close()
-    buffer.toString.stripLineEnd
+    val lines = buffer.toString.stripLineEnd
+    buffer = new ByteArrayOutputStream()
+    writer = new CsvWriter(
+      new OutputStreamWriter(buffer, StandardCharsets.UTF_8),
+      writerSettings)
+    lines
   }
 }
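With this hunk applied, LineCsvWriter keeps a single ByteArrayOutputStream/CsvWriter pair alive across rows: writeRow only appends, while flush() closes the writer, drains the accumulated text, and re-creates the buffer/writer for the next batch. A hedged usage sketch of that API (`params`, `headers`, and `rows` are assumed to be in scope, much as they are in CsvOutputWriter below):

```scala
val csvWriter = new LineCsvWriter(params, headers)
rows.zipWithIndex.foreach { case (row, i) =>
  // the header (if enabled) goes out ahead of the very first row only
  csvWriter.writeRow(row, includeHeader = i == 0 && params.headerFlag)
}
// drain everything buffered since the last flush and emit it in one write
val lines = csvWriter.flush()
if (lines.nonEmpty) {
  // hand `lines` to the underlying sink, as CsvOutputWriter does in the next file
}
```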
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index 4f2d4387b1..9849484dce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -176,8 +176,8 @@ private[sql] class CsvOutputWriter(
     }.getRecordWriter(context)
   }
 
-  private var firstRow: Boolean = params.headerFlag
-
+  private val FLUSH_BATCH_SIZE = 1024L
+  private var records: Long = 0L
   private val csvWriter = new LineCsvWriter(params, dataSchema.fieldNames.toSeq)
 
   private def rowToString(row: Seq[Any]): Seq[String] = row.map { field =>
@@ -191,16 +191,23 @@ private[sql] class CsvOutputWriter(
   override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
 
   override protected[sql] def writeInternal(row: InternalRow): Unit = {
-    // TODO: Instead of converting and writing every row, we should use the univocity buffer
-    val resultString = csvWriter.writeRow(rowToString(row.toSeq(dataSchema)), firstRow)
-    if (firstRow) {
-      firstRow = false
+    csvWriter.writeRow(rowToString(row.toSeq(dataSchema)), records == 0L && params.headerFlag)
+    records += 1
+    if (records % FLUSH_BATCH_SIZE == 0) {
+      flush()
+    }
+  }
+
+  private def flush(): Unit = {
+    val lines = csvWriter.flush()
+    if (lines.nonEmpty) {
+      text.set(lines)
+      recordWriter.write(NullWritable.get(), text)
     }
-    text.set(resultString)
-    recordWriter.write(NullWritable.get(), text)
   }
 
   override def close(): Unit = {
+    flush()
     recordWriter.close(context)
   }
 }
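Putting the two hunks together, the pattern in CsvOutputWriter is: count rows, drain the buffered writer every FLUSH_BATCH_SIZE rows, and drain once more on close so the final partial batch is not lost. A standalone sketch of that wrapper (hypothetical class, with a plain `emit` function standing in for the Hadoop RecordWriter):

```scala
// Sketch only: the flush-every-N-records pattern from the diff above, with the
// Hadoop RecordWriter replaced by a caller-supplied emit function.
class BatchedCsvSink(
    params: CSVOptions,
    headers: Seq[String],
    emit: String => Unit,
    batchSize: Long = 1024L) {

  private val csvWriter = new LineCsvWriter(params, headers)
  private var records = 0L

  def add(row: Seq[String]): Unit = {
    // header (if enabled) is written together with the very first row only
    csvWriter.writeRow(row, includeHeader = records == 0L && params.headerFlag)
    records += 1
    if (records % batchSize == 0) drain()
  }

  private def drain(): Unit = {
    val lines = csvWriter.flush()
    if (lines.nonEmpty) emit(lines)
  }

  // flush the tail of the last (possibly partial) batch before shutting down
  def close(): Unit = drain()
}
```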