-rw-r--r--  python/pyspark/sql/readwriter.py | 55
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala | 19
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala | 22
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala | 16
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala | 46
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala | 46
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala | 55
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala | 36
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala | 29
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala | 20
10 files changed, 301 insertions, 43 deletions
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 7f5368d8bd..438662bb15 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -454,7 +454,7 @@ class DataFrameWriter(object):
self._jwrite.saveAsTable(name)
@since(1.4)
- def json(self, path, mode=None):
+ def json(self, path, mode=None, compression=None):
"""Saves the content of the :class:`DataFrame` in JSON format at the specified path.
:param path: the path in any Hadoop supported file system
@@ -464,18 +464,19 @@ class DataFrameWriter(object):
* ``overwrite``: Overwrite existing data.
* ``ignore``: Silently ignore this operation if data already exists.
* ``error`` (default case): Throw an exception if data already exists.
-
- You can set the following JSON-specific option(s) for writing JSON files:
- * ``compression`` (default ``None``): compression codec to use when saving to file.
- This can be one of the known case-insensitive shorten names
- (``bzip2``, ``gzip``, ``lz4``, and ``snappy``).
+ :param compression: compression codec to use when saving to file. This can be one of the
+ known case-insensitive shortened names (none, bzip2, gzip, lz4,
+ snappy and deflate).
>>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data'))
"""
- self.mode(mode)._jwrite.json(path)
+ self.mode(mode)
+ if compression is not None:
+ self.option("compression", compression)
+ self._jwrite.json(path)
@since(1.4)
- def parquet(self, path, mode=None, partitionBy=None):
+ def parquet(self, path, mode=None, partitionBy=None, compression=None):
"""Saves the content of the :class:`DataFrame` in Parquet format at the specified path.
:param path: the path in any Hadoop supported file system
@@ -486,32 +487,37 @@ class DataFrameWriter(object):
* ``ignore``: Silently ignore this operation if data already exists.
* ``error`` (default case): Throw an exception if data already exists.
:param partitionBy: names of partitioning columns
+ :param compression: compression codec to use when saving to file. This can be one of the
+ known case-insensitive shortened names (none, snappy, gzip, and lzo).
+ This will override ``spark.sql.parquet.compression.codec``.
>>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data'))
"""
self.mode(mode)
if partitionBy is not None:
self.partitionBy(partitionBy)
+ if compression is not None:
+ self.option("compression", compression)
self._jwrite.parquet(path)
@since(1.6)
- def text(self, path):
+ def text(self, path, compression=None):
"""Saves the content of the DataFrame in a text file at the specified path.
:param path: the path in any Hadoop supported file system
+ :param compression: compression codec to use when saving to file. This can be one of the
+ known case-insensitive shortened names (none, bzip2, gzip, lz4,
+ snappy and deflate).
The DataFrame must have only one column that is of string type.
Each row becomes a new line in the output file.
-
- You can set the following option(s) for writing text files:
- * ``compression`` (default ``None``): compression codec to use when saving to file.
- This can be one of the known case-insensitive shorten names
- (``bzip2``, ``gzip``, ``lz4``, and ``snappy``).
"""
+ if compression is not None:
+ self.option("compression", compression)
self._jwrite.text(path)
@since(2.0)
- def csv(self, path, mode=None):
+ def csv(self, path, mode=None, compression=None):
"""Saves the content of the [[DataFrame]] in CSV format at the specified path.
:param path: the path in any Hadoop supported file system
@@ -522,17 +528,19 @@ class DataFrameWriter(object):
* ``ignore``: Silently ignore this operation if data already exists.
* ``error`` (default case): Throw an exception if data already exists.
- You can set the following CSV-specific option(s) for writing CSV files:
- * ``compression`` (default ``None``): compression codec to use when saving to file.
- This can be one of the known case-insensitive shorten names
- (``bzip2``, ``gzip``, ``lz4``, and ``snappy``).
+ :param compression: compression codec to use when saving to file. This can be one of the
+ known case-insensitive shortened names (none, bzip2, gzip, lz4,
+ snappy and deflate).
>>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data'))
"""
- self.mode(mode)._jwrite.csv(path)
+ self.mode(mode)
+ if compression is not None:
+ self.option("compression", compression)
+ self._jwrite.csv(path)
@since(1.5)
- def orc(self, path, mode=None, partitionBy=None):
+ def orc(self, path, mode=None, partitionBy=None, compression=None):
"""Saves the content of the :class:`DataFrame` in ORC format at the specified path.
::Note: Currently ORC support is only available together with
@@ -546,6 +554,9 @@ class DataFrameWriter(object):
* ``ignore``: Silently ignore this operation if data already exists.
* ``error`` (default case): Throw an exception if data already exists.
:param partitionBy: names of partitioning columns
+ :param compression: compression codec to use when saving to file. This can be one of the
+ known case-insensitive shortened names (none, snappy, zlib, and lzo).
+ This will override ``orc.compress``.
>>> orc_df = hiveContext.read.orc('python/test_support/sql/orc_partitioned')
>>> orc_df.write.orc(os.path.join(tempfile.mkdtemp(), 'data'))
@@ -553,6 +564,8 @@ class DataFrameWriter(object):
self.mode(mode)
if partitionBy is not None:
self.partitionBy(partitionBy)
+ if compression is not None:
+ self.option("compression", compression)
self._jwrite.orc(path)
@since(1.4)
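Taken together, the PySpark changes above surface the same "compression" option as a keyword argument on each format-specific writer. A minimal usage sketch, assuming an existing DataFrame `df` and the doctest-style setup used in readwriter.py (paths and the `name` column are illustrative):

    import os
    import tempfile

    base = tempfile.mkdtemp()

    # JSON, CSV and text accept none, bzip2, gzip, lz4, snappy and deflate.
    df.write.json(os.path.join(base, 'json_gz'), mode='overwrite', compression='gzip')
    df.write.csv(os.path.join(base, 'csv_bz2'), mode='overwrite', compression='bzip2')
    # text() requires a single string column; 'name' is assumed here.
    df.select('name').write.text(os.path.join(base, 'text_none'), compression='none')

    # Parquet accepts none, snappy, gzip and lzo; the keyword takes precedence over
    # spark.sql.parquet.compression.codec for this write only.
    df.write.parquet(os.path.join(base, 'parquet_snappy'), compression='snappy')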
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 118e0e9ace..c373606a2e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -455,7 +455,8 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*
* You can set the following JSON-specific option(s) for writing JSON files:
* <li>`compression` (default `null`): compression codec to use when saving to file. This can be
- * one of the known case-insensitive shorten names (`bzip2`, `gzip`, `lz4`, and `snappy`). </li>
+ * one of the known case-insensitive shortened names (`none`, `bzip2`, `gzip`, `lz4`,
+ * `snappy` and `deflate`). </li>
*
* @since 1.4.0
*/
@@ -468,6 +469,11 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* format("parquet").save(path)
* }}}
*
+ * You can set the following Parquet-specific option(s) for writing Parquet files:
+ * <li>`compression` (default `null`): compression codec to use when saving to file. This can be
+ * one of the known case-insensitive shortened names (`none`, `snappy`, `gzip`, and `lzo`).
+ * This will override `spark.sql.parquet.compression.codec`. </li>
+ *
* @since 1.4.0
*/
def parquet(path: String): Unit = format("parquet").save(path)
@@ -479,6 +485,11 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* format("orc").save(path)
* }}}
*
+ * You can set the following ORC-specific option(s) for writing ORC files:
+ * <li>`compression` (default `null`): compression codec to use when saving to file. This can be
+ * one of the known case-insensitive shortened names (`none`, `snappy`, `zlib`, and `lzo`).
+ * This will override `orc.compress`. </li>
+ *
* @since 1.5.0
* @note Currently, this method can only be used together with `HiveContext`.
*/
@@ -498,7 +509,8 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*
* You can set the following option(s) for writing text files:
* <li>`compression` (default `null`): compression codec to use when saving to file. This can be
- * one of the known case-insensitive shorten names (`bzip2`, `gzip`, `lz4`, and `snappy`). </li>
+ * one of the known case-insensitive shortened names (`none`, `bzip2`, `gzip`, `lz4`,
+ * `snappy` and `deflate`). </li>
*
* @since 1.6.0
*/
@@ -513,7 +525,8 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*
* You can set the following CSV-specific option(s) for writing CSV files:
* <li>`compression` (default `null`): compression codec to use when saving to file. This can be
- * one of the known case-insensitive shorten names (`bzip2`, `gzip`, `lz4`, and `snappy`). </li>
+ * one of the known case-insensitive shortened names (`none`, `bzip2`, `gzip`, `lz4`,
+ * `snappy` and `deflate`). </li>
*
* @since 2.0.0
*/
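The Scala docs above describe the same setting reached through the generic option() route; the format-specific keywords in PySpark simply set that option before saving. A short sketch under the same assumptions as above (existing `df`, illustrative path):

    import os
    import tempfile

    out = os.path.join(tempfile.mkdtemp(), 'data')

    # Equivalent to df.write.json(out, compression='gzip').
    df.write.format('json').option('compression', 'gzip').mode('overwrite').save(out)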
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala
index 9e913de48f..032ba61d9d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala
@@ -25,6 +25,8 @@ import org.apache.spark.util.Utils
private[datasources] object CompressionCodecs {
private val shortCompressionCodecNames = Map(
+ "none" -> null,
+ "uncompressed" -> null,
"bzip2" -> classOf[BZip2Codec].getName,
"deflate" -> classOf[DeflateCodec].getName,
"gzip" -> classOf[GzipCodec].getName,
@@ -39,7 +41,9 @@ private[datasources] object CompressionCodecs {
val codecName = shortCompressionCodecNames.getOrElse(name.toLowerCase, name)
try {
// Validate the codec name
- Utils.classForName(codecName)
+ if (codecName != null) {
+ Utils.classForName(codecName)
+ }
codecName
} catch {
case e: ClassNotFoundException =>
@@ -53,10 +57,16 @@ private[datasources] object CompressionCodecs {
* `codec` should be a full class path
*/
def setCodecConfiguration(conf: Configuration, codec: String): Unit = {
- conf.set("mapreduce.output.fileoutputformat.compress", "true")
- conf.set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString)
- conf.set("mapreduce.output.fileoutputformat.compress.codec", codec)
- conf.set("mapreduce.map.output.compress", "true")
- conf.set("mapreduce.map.output.compress.codec", codec)
+ if (codec != null) {
+ conf.set("mapreduce.output.fileoutputformat.compress", "true")
+ conf.set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString)
+ conf.set("mapreduce.output.fileoutputformat.compress.codec", codec)
+ conf.set("mapreduce.map.output.compress", "true")
+ conf.set("mapreduce.map.output.compress.codec", codec)
+ } else {
+ // This means the option `compression` was set to `uncompressed` or `none`.
+ conf.set("mapreduce.output.fileoutputformat.compress", "false")
+ conf.set("mapreduce.map.output.compress", "false")
+ }
}
}
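With this change, "none" and "uncompressed" resolve to a null codec class and setCodecConfiguration then switches Hadoop output compression off instead of setting a codec. The following Python sketch only mirrors that lookup logic for illustration (it is not Spark code; the lz4/snappy entries are inferred from the codec list quoted in TextSuite, and a plain dict stands in for the Hadoop Configuration):

    # Illustrative mirror of CompressionCodecs.getCodecClassName / setCodecConfiguration.
    SHORT_CODEC_NAMES = {
        'none': None,
        'uncompressed': None,
        'bzip2': 'org.apache.hadoop.io.compress.BZip2Codec',
        'deflate': 'org.apache.hadoop.io.compress.DeflateCodec',
        'gzip': 'org.apache.hadoop.io.compress.GzipCodec',
        'lz4': 'org.apache.hadoop.io.compress.Lz4Codec',
        'snappy': 'org.apache.hadoop.io.compress.SnappyCodec',
    }

    def set_codec_configuration(conf, codec):
        """`conf` is a dict standing in for a Hadoop Configuration."""
        if codec is not None:
            conf['mapreduce.output.fileoutputformat.compress'] = 'true'
            conf['mapreduce.output.fileoutputformat.compress.codec'] = codec
            conf['mapreduce.map.output.compress'] = 'true'
            conf['mapreduce.map.output.compress.codec'] = codec
        else:
            # "none" / "uncompressed" resolved to None, so compression is disabled.
            conf['mapreduce.output.fileoutputformat.compress'] = 'false'
            conf['mapreduce.map.output.compress'] = 'false'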
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
index a1806221b7..7ea098c72b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
@@ -148,6 +148,19 @@ private[sql] class ParquetRelation(
.get(ParquetRelation.METASTORE_SCHEMA)
.map(DataType.fromJson(_).asInstanceOf[StructType])
+ private val compressionCodec: Option[String] = parameters
+ .get("compression")
+ .map { codecName =>
+ // Validate if given compression codec is supported or not.
+ val shortParquetCompressionCodecNames = ParquetRelation.shortParquetCompressionCodecNames
+ if (!shortParquetCompressionCodecNames.contains(codecName.toLowerCase)) {
+ val availableCodecs = shortParquetCompressionCodecNames.keys.map(_.toLowerCase)
+ throw new IllegalArgumentException(s"Codec [$codecName] " +
+ s"is not available. Available codecs are ${availableCodecs.mkString(", ")}.")
+ }
+ codecName.toLowerCase
+ }
+
private lazy val metadataCache: MetadataCache = {
val meta = new MetadataCache
meta.refresh()
@@ -286,7 +299,8 @@ private[sql] class ParquetRelation(
ParquetRelation
.shortParquetCompressionCodecNames
.getOrElse(
- sqlContext.conf.parquetCompressionCodec.toLowerCase(),
+ compressionCodec
+ .getOrElse(sqlContext.conf.parquetCompressionCodec.toLowerCase),
CompressionCodecName.UNCOMPRESSED).name())
new BucketedOutputWriterFactory {
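From the user's side, the ParquetRelation change means a per-write compression option wins over the session-wide spark.sql.parquet.compression.codec setting, and an unrecognised name fails fast with an IllegalArgumentException listing the available codecs. A hedged PySpark sketch, assuming an existing `sqlContext` and `df` (path illustrative):

    import os
    import tempfile

    out = os.path.join(tempfile.mkdtemp(), 'data')

    # The session default requests no compression...
    sqlContext.setConf('spark.sql.parquet.compression.codec', 'uncompressed')

    # ...but this particular write is gzip-compressed; codec names are case-insensitive.
    df.write.option('compression', 'GzIP').parquet(out)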
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 3ecbb14f2e..9cd3a9ab95 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -21,6 +21,12 @@ import java.io.File
import java.nio.charset.UnsupportedCharsetException
import java.sql.Timestamp
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.io.SequenceFile.CompressionType
+import org.apache.hadoop.io.compress.GzipCodec
+
import org.apache.spark.SparkException
import org.apache.spark.sql.{DataFrame, QueryTest, Row}
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
@@ -396,6 +402,46 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
}
+ test("SPARK-13543 Write the output as uncompressed via option()") {
+ val clonedConf = new Configuration(hadoopConfiguration)
+ hadoopConfiguration.set("mapreduce.output.fileoutputformat.compress", "true")
+ hadoopConfiguration
+ .set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString)
+ hadoopConfiguration
+ .set("mapreduce.output.fileoutputformat.compress.codec", classOf[GzipCodec].getName)
+ hadoopConfiguration.set("mapreduce.map.output.compress", "true")
+ hadoopConfiguration.set("mapreduce.map.output.compress.codec", classOf[GzipCodec].getName)
+ withTempDir { dir =>
+ try {
+ val csvDir = new File(dir, "csv").getCanonicalPath
+ val cars = sqlContext.read
+ .format("csv")
+ .option("header", "true")
+ .load(testFile(carsFile))
+
+ cars.coalesce(1).write
+ .format("csv")
+ .option("header", "true")
+ .option("compression", "none")
+ .save(csvDir)
+
+ val compressedFiles = new File(csvDir).listFiles()
+ assert(compressedFiles.exists(!_.getName.endsWith(".gz")))
+
+ val carsCopy = sqlContext.read
+ .format("csv")
+ .option("header", "true")
+ .load(csvDir)
+
+ verifyCars(carsCopy, withHeader = true)
+ } finally {
+ // Hadoop 1 doesn't have `Configuration.unset`
+ hadoopConfiguration.clear()
+ clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue))
+ }
+ }
+ }
+
test("Schema inference correctly identifies the datatype when data is sparse.") {
val df = sqlContext.read
.format("csv")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index c7f33e1746..3a33554143 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -23,9 +23,10 @@ import java.sql.{Date, Timestamp}
import scala.collection.JavaConverters._
import com.fasterxml.jackson.core.JsonFactory
-import org.apache.commons.io.FileUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, PathFilter}
+import org.apache.hadoop.io.SequenceFile.CompressionType
+import org.apache.hadoop.io.compress.GzipCodec
import org.scalactic.Tolerance._
import org.apache.spark.rdd.RDD
@@ -1524,6 +1525,49 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
}
+ test("SPARK-13543 Write the output as uncompressed via option()") {
+ val clonedConf = new Configuration(hadoopConfiguration)
+ hadoopConfiguration.set("mapreduce.output.fileoutputformat.compress", "true")
+ hadoopConfiguration
+ .set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString)
+ hadoopConfiguration
+ .set("mapreduce.output.fileoutputformat.compress.codec", classOf[GzipCodec].getName)
+ hadoopConfiguration.set("mapreduce.map.output.compress", "true")
+ hadoopConfiguration.set("mapreduce.map.output.compress.codec", classOf[GzipCodec].getName)
+ withTempDir { dir =>
+ try {
+ // `saveAsTextFile` requires a non-existent path, so shadow the outer `dir`
+ // with a fresh temp dir and delete it before writing.
+ val dir = Utils.createTempDir()
+ dir.delete()
+
+ val path = dir.getCanonicalPath
+ primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path)
+
+ val jsonDF = sqlContext.read.json(path)
+ val jsonDir = new File(dir, "json").getCanonicalPath
+ jsonDF.coalesce(1).write
+ .format("json")
+ .option("compression", "none")
+ .save(jsonDir)
+
+ val compressedFiles = new File(jsonDir).listFiles()
+ assert(compressedFiles.exists(!_.getName.endsWith(".gz")))
+
+ val jsonCopy = sqlContext.read
+ .format("json")
+ .load(jsonDir)
+
+ assert(jsonCopy.count == jsonDF.count)
+ val jsonCopySome = jsonCopy.selectExpr("string", "long", "boolean")
+ val jsonDFSome = jsonDF.selectExpr("string", "long", "boolean")
+ checkAnswer(jsonCopySome, jsonDFSome)
+ } finally {
+ // Hadoop 1 doesn't have `Configuration.unset`
+ hadoopConfiguration.clear()
+ clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue))
+ }
+ }
+ }
+
test("Casting long as timestamp") {
withTempTable("jsonTable") {
val schema = (new StructType).add("ts", TimestampType)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
index 59e0e6a7cf..9eb1016b64 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
@@ -17,6 +17,14 @@
package org.apache.spark.sql.execution.datasources.text
+import java.io.File
+
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.io.SequenceFile.CompressionType
+import org.apache.hadoop.io.compress.GzipCodec
+
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{StringType, StructType}
@@ -59,20 +67,49 @@ class TextSuite extends QueryTest with SharedSQLContext {
test("SPARK-13503 Support to specify the option for compression codec for TEXT") {
val testDf = sqlContext.read.text(testFile)
-
- Seq("bzip2", "deflate", "gzip").foreach { codecName =>
- val tempDir = Utils.createTempDir()
- val tempDirPath = tempDir.getAbsolutePath()
- testDf.write.option("compression", codecName).mode(SaveMode.Overwrite).text(tempDirPath)
- verifyFrame(sqlContext.read.text(tempDirPath))
+ val extensionNameMap = Map("bzip2" -> ".bz2", "deflate" -> ".deflate", "gzip" -> ".gz")
+ extensionNameMap.foreach {
+ case (codecName, extension) =>
+ val tempDir = Utils.createTempDir()
+ val tempDirPath = tempDir.getAbsolutePath
+ testDf.write.option("compression", codecName).mode(SaveMode.Overwrite).text(tempDirPath)
+ val compressedFiles = new File(tempDirPath).listFiles()
+ assert(compressedFiles.exists(_.getName.endsWith(extension)))
+ verifyFrame(sqlContext.read.text(tempDirPath))
}
val errMsg = intercept[IllegalArgumentException] {
- val tempDirPath = Utils.createTempDir().getAbsolutePath()
+ val tempDirPath = Utils.createTempDir().getAbsolutePath
testDf.write.option("compression", "illegal").mode(SaveMode.Overwrite).text(tempDirPath)
}
- assert(errMsg.getMessage === "Codec [illegal] is not available. " +
- "Known codecs are bzip2, deflate, lz4, gzip, snappy.")
+ assert(errMsg.getMessage.contains("Codec [illegal] is not available. " +
+ "Known codecs are"))
+ }
+
+ test("SPARK-13543 Write the output as uncompressed via option()") {
+ val clonedConf = new Configuration(hadoopConfiguration)
+ hadoopConfiguration.set("mapreduce.output.fileoutputformat.compress", "true")
+ hadoopConfiguration
+ .set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString)
+ hadoopConfiguration
+ .set("mapreduce.output.fileoutputformat.compress.codec", classOf[GzipCodec].getName)
+ hadoopConfiguration.set("mapreduce.map.output.compress", "true")
+ hadoopConfiguration.set("mapreduce.map.output.compress.codec", classOf[GzipCodec].getName)
+ withTempDir { dir =>
+ try {
+ val testDf = sqlContext.read.text(testFile)
+ val tempDir = Utils.createTempDir()
+ val tempDirPath = tempDir.getAbsolutePath
+ testDf.write.option("compression", "none").mode(SaveMode.Overwrite).text(tempDirPath)
+ val compressedFiles = new File(tempDirPath).listFiles()
+ assert(compressedFiles.exists(!_.getName.endsWith(".gz")))
+ verifyFrame(sqlContext.read.text(tempDirPath))
+ } finally {
+ // Hadoop 1 doesn't have `Configuration.unset`
+ hadoopConfiguration.clear()
+ clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue))
+ }
+ }
}
private def testFile: String = {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
index 800823feba..2b06e1a12c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
@@ -23,7 +23,8 @@ import com.google.common.base.Objects
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
-import org.apache.hadoop.hive.ql.io.orc.{OrcInputFormat, OrcOutputFormat, OrcSerde, OrcSplit, OrcStruct}
+import org.apache.hadoop.hive.ql.io.orc._
+import org.apache.hadoop.hive.ql.io.orc.OrcFile.OrcTableProperties
import org.apache.hadoop.hive.serde2.objectinspector.SettableStructObjectInspector
import org.apache.hadoop.hive.serde2.typeinfo.{StructTypeInfo, TypeInfoUtils}
import org.apache.hadoop.io.{NullWritable, Writable}
@@ -162,6 +163,19 @@ private[sql] class OrcRelation(
extends HadoopFsRelation(maybePartitionSpec, parameters)
with Logging {
+ private val compressionCodec: Option[String] = parameters
+ .get("compression")
+ .map { codecName =>
+ // Validate if given compression codec is supported or not.
+ val shortOrcCompressionCodecNames = OrcRelation.shortOrcCompressionCodecNames
+ if (!shortOrcCompressionCodecNames.contains(codecName.toLowerCase)) {
+ val availableCodecs = shortOrcCompressionCodecNames.keys.map(_.toLowerCase)
+ throw new IllegalArgumentException(s"Codec [$codecName] " +
+ s"is not available. Available codecs are ${availableCodecs.mkString(", ")}.")
+ }
+ codecName.toLowerCase
+ }
+
private[sql] def this(
paths: Array[String],
maybeDataSchema: Option[StructType],
@@ -211,6 +225,15 @@ private[sql] class OrcRelation(
}
override def prepareJobForWrite(job: Job): BucketedOutputWriterFactory = {
+ // Sets compression scheme
+ compressionCodec.foreach { codecName =>
+ job.getConfiguration.set(
+ OrcTableProperties.COMPRESSION.getPropName,
+ OrcRelation
+ .shortOrcCompressionCodecNames
+ .getOrElse(codecName, CompressionKind.NONE).name())
+ }
+
job.getConfiguration match {
case conf: JobConf =>
conf.setOutputFormat(classOf[OrcOutputFormat])
@@ -337,3 +360,14 @@ private[orc] object OrcTableScan {
// This constant duplicates `OrcInputFormat.SARG_PUSHDOWN`, which is unfortunately not public.
private[orc] val SARG_PUSHDOWN = "sarg.pushdown"
}
+
+private[orc] object OrcRelation {
+ // The ORC compression short names
+ val shortOrcCompressionCodecNames = Map(
+ "none" -> CompressionKind.NONE,
+ "uncompressed" -> CompressionKind.NONE,
+ "snappy" -> CompressionKind.SNAPPY,
+ "zlib" -> CompressionKind.ZLIB,
+ "lzo" -> CompressionKind.LZO)
+}
+
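OrcRelation gains the same option, mapped onto ORC's CompressionKind and written into the orc.compress job property. A short PySpark sketch; as the Python docstring above notes, this assumes a HiveContext-backed DataFrame `df` (path illustrative):

    import os
    import tempfile

    out = os.path.join(tempfile.mkdtemp(), 'data')

    # Valid short names are none, uncompressed, snappy, zlib and lzo (case-insensitive).
    df.write.orc(out, compression='zlib')

    # Or through the generic writer:
    df.write.format('orc').option('compression', 'snappy').mode('overwrite').save(out)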
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala
index 528f40b002..823b52af1b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala
@@ -17,7 +17,10 @@
package org.apache.spark.sql.hive.orc
-import org.apache.hadoop.fs.Path
+import java.io.File
+
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.hive.ql.io.orc.{CompressionKind, OrcFile}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql.Row
@@ -81,4 +84,28 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest {
}
}
}
+
+ test("SPARK-13543: Support for specifying compression codec for ORC via option()") {
+ withTempPath { dir =>
+ val path = s"${dir.getCanonicalPath}/table1"
+ val df = (1 to 5).map(i => (i, (i % 2).toString)).toDF("a", "b")
+ df.write
+ .option("compression", "ZlIb")
+ .orc(path)
+
+ // Check if this is compressed as ZLIB.
+ val conf = sparkContext.hadoopConfiguration
+ val fs = FileSystem.getLocal(conf)
+ val maybeOrcFile = new File(path).listFiles().find(_.getName.endsWith(".orc"))
+ assert(maybeOrcFile.isDefined)
+ val orcFilePath = new Path(maybeOrcFile.get.toPath.toString)
+ val orcReader = OrcFile.createReader(orcFilePath, OrcFile.readerOptions(conf))
+ assert(orcReader.getCompression == CompressionKind.ZLIB)
+
+ val copyDf = sqlContext
+ .read
+ .orc(path)
+ checkAnswer(df, copyDf)
+ }
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala
index f2501d7ce3..8856148a95 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala
@@ -208,4 +208,24 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
checkAnswer(loadedDF, df)
}
}
+
+ test("SPARK-13543: Support for specifying compression codec for Parquet via option()") {
+ withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "UNCOMPRESSED") {
+ withTempPath { dir =>
+ val path = s"${dir.getCanonicalPath}/table1"
+ val df = (1 to 5).map(i => (i, (i % 2).toString)).toDF("a", "b")
+ df.write
+ .option("compression", "GzIP")
+ .parquet(path)
+
+ val compressedFiles = new File(path).listFiles()
+ assert(compressedFiles.exists(_.getName.endsWith(".gz.parquet")))
+
+ val copyDf = sqlContext
+ .read
+ .parquet(path)
+ checkAnswer(df, copyDf)
+ }
+ }
+ }
}