11 files changed, 115 insertions, 13 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index 7629369ab1..b5b2a681e9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -112,7 +112,7 @@ private[libsvm] class LibSVMOutputWriter(
  */
 // If this is moved or renamed, please update DataSource's backwardCompatibilityMap.
 @Since("1.6.0")
-class LibSVMFileFormat extends FileFormat with DataSourceRegister {
+class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister {
 
   @Since("1.6.0")
   override def shortName(): String = "libsvm"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index 7503285ee2..13a86bfb38 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -151,11 +151,18 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
       val splitFiles = selectedPartitions.flatMap { partition =>
         partition.files.flatMap { file =>
           val blockLocations = getBlockLocations(file)
-          (0L until file.getLen by maxSplitBytes).map { offset =>
-            val remaining = file.getLen - offset
-            val size = if (remaining > maxSplitBytes) maxSplitBytes else remaining
-            val hosts = getBlockHosts(blockLocations, offset, size)
-            PartitionedFile(partition.values, file.getPath.toUri.toString, offset, size, hosts)
+          if (files.fileFormat.isSplitable(files.sparkSession, files.options, file.getPath)) {
+            (0L until file.getLen by maxSplitBytes).map { offset =>
+              val remaining = file.getLen - offset
+              val size = if (remaining > maxSplitBytes) maxSplitBytes else remaining
+              val hosts = getBlockHosts(blockLocations, offset, size)
+              PartitionedFile(
+                partition.values, file.getPath.toUri.toString, offset, size, hosts)
+            }
+          } else {
+            val hosts = getBlockHosts(blockLocations, 0, file.getLen)
+            Seq(PartitionedFile(
+              partition.values, file.getPath.toUri.toString, 0, file.getLen, hosts))
           }
         }
       }.toArray.sortBy(_.length)(implicitly[Ordering[Long]].reverse)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index 4d36b76056..be52de8e40 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -36,7 +36,7 @@ import org.apache.spark.util.SerializableConfiguration
 /**
  * Provides access to CSV data from pure SQL statements.
  */
-class CSVFileFormat extends FileFormat with DataSourceRegister {
+class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
 
   override def shortName(): String = "csv"
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
index 7f3eed3fb1..890e64db59 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
@@ -22,6 +22,7 @@ import scala.util.Try
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs._
+import org.apache.hadoop.io.compress.{CompressionCodecFactory, SplittableCompressionCodec}
 import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
 import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 
@@ -29,12 +30,12 @@ import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.execution.FileRelation
 import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, Filter}
-import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.SerializableConfiguration
 
 /**
@@ -215,6 +216,16 @@ trait FileFormat {
   }
 
   /**
+   * Returns whether a file with `path` could be splitted or not.
+   */
+  def isSplitable(
+      sparkSession: SparkSession,
+      options: Map[String, String],
+      path: Path): Boolean = {
+    false
+  }
+
+  /**
    * Returns a function that can be used to read a single file in as an Iterator of InternalRow.
    *
    * @param dataSchema The global data schema. It can be either specified by the user, or
@@ -298,6 +309,24 @@ trait FileFormat {
 }
 
 /**
+ * The base class file format that is based on text file.
+ */
+abstract class TextBasedFileFormat extends FileFormat {
+  private var codecFactory: CompressionCodecFactory = null
+  override def isSplitable(
+      sparkSession: SparkSession,
+      options: Map[String, String],
+      path: Path): Boolean = {
+    if (codecFactory == null) {
+      codecFactory = new CompressionCodecFactory(
+        sparkSession.sessionState.newHadoopConfWithOptions(options))
+    }
+    val codec = codecFactory.getCodec(path)
+    codec == null || codec.isInstanceOf[SplittableCompressionCodec]
+  }
+}
+
+/**
  * A collection of data files from a partitioned relation, along with the partition values in the
  * form of an [[InternalRow]].
  */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index c7c5281196..86aef1f7d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.SerializableConfiguration
 
-class JsonFileFormat extends FileFormat with DataSourceRegister {
+class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
 
   override def shortName(): String = "json"
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index ada9cd4b8e..3735c94968 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -261,6 +261,13 @@ private[sql] class ParquetFileFormat
     schema.forall(_.dataType.isInstanceOf[AtomicType])
   }
 
+  override def isSplitable(
+      sparkSession: SparkSession,
+      options: Map[String, String],
+      path: Path): Boolean = {
+    true
+  }
+
   override private[sql] def buildReaderWithPartitionValues(
       sparkSession: SparkSession,
       dataSchema: StructType,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
index 9c03ab28dd..abb6059f75 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
@@ -36,7 +36,7 @@ import org.apache.spark.util.SerializableConfiguration
 /**
  * A data source for reading text files.
  */
-class TextFileFormat extends FileFormat with DataSourceRegister {
+class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {
 
   override def shortName(): String = "text"
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
index 25f1443e70..67ff257b93 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
@@ -340,6 +340,41 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
     }
   }
 
+  test("SPARK-15654 do not split non-splittable files") {
+    // Check if a non-splittable file is not assigned into partitions
+    Seq("gz", "snappy", "lz4").map { suffix =>
+      val table = createTable(
+        files = Seq(s"file1.${suffix}" -> 3, s"file2.${suffix}" -> 1, s"file3.${suffix}" -> 1)
+      )
+      withSQLConf(
+        SQLConf.FILES_MAX_PARTITION_BYTES.key -> "2",
+        SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "0") {
+        checkScan(table.select('c1)) { partitions =>
+          assert(partitions.size == 2)
+          assert(partitions(0).files.size == 1)
+          assert(partitions(1).files.size == 2)
+        }
+      }
+    }
+
+    // Check if a splittable compressed file is assigned into multiple partitions
+    Seq("bz2").map { suffix =>
+      val table = createTable(
+        files = Seq(s"file1.${suffix}" -> 3, s"file2.${suffix}" -> 1, s"file3.${suffix}" -> 1)
+      )
+      withSQLConf(
+        SQLConf.FILES_MAX_PARTITION_BYTES.key -> "2",
+        SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "0") {
+        checkScan(table.select('c1)) { partitions =>
+          assert(partitions.size == 3)
+          assert(partitions(0).files.size == 1)
+          assert(partitions(1).files.size == 2)
+          assert(partitions(2).files.size == 1)
+        }
+      }
+    }
+  }
+
   // Helpers for checking the arguments passed to the FileFormat.
 
   protected val checkPartitionSchema =
@@ -434,7 +469,7 @@ object LastArguments {
 }
 
 /** A test [[FileFormat]] that records the arguments passed to buildReader, and returns nothing. */
-class TestFileFormat extends FileFormat {
+class TestFileFormat extends TextBasedFileFormat {
 
   override def toString: String = "TestFileFormat"
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
index 7b6981f95e..5695f6af7b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
@@ -23,6 +23,7 @@ import org.apache.hadoop.io.SequenceFile.CompressionType
 import org.apache.hadoop.io.compress.GzipCodec
 
 import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.util.Utils
@@ -137,6 +138,22 @@ class TextSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("SPARK-15654: should not split gz files") {
+    withTempDir { dir =>
+      val path = dir.getCanonicalPath
+      val df1 = spark.range(0, 1000).selectExpr("CAST(id AS STRING) AS s")
+      df1.write.option("compression", "gzip").mode("overwrite").text(path)
+
+      val expected = df1.collect()
+      Seq(10, 100, 1000).foreach { bytes =>
+        withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> bytes.toString) {
+          val df2 = spark.read.format("text").load(path)
+          checkAnswer(df2, expected)
+        }
+      }
+    }
+  }
+
   private def testFile: String = {
     Thread.currentThread().getContextClassLoader.getResource("text-suite.txt").toString
   }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
index 0e8c37df88..a2c8092e01 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
@@ -95,6 +95,13 @@ private[sql] class OrcFileFormat
     }
   }
 
+  override def isSplitable(
+      sparkSession: SparkSession,
+      options: Map[String, String],
+      path: Path): Boolean = {
+    true
+  }
+
   override def buildReader(
       sparkSession: SparkSession,
       dataSchema: StructType,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index 1fb777ade4..67a58a3859 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.util.SerializableConfiguration
 
-class SimpleTextSource extends FileFormat with DataSourceRegister {
+class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
   override def shortName(): String = "test"
 
   override def inferSchema(
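
For third-party data sources, the upshot of this patch is that a line-oriented format no longer has to reason about compression itself: extending TextBasedFileFormat inherits the codec check added above. A minimal sketch under the Spark 2.0-era API shown in this diff (the class name FixedWidthFileFormat and its stubbed-out write path are hypothetical, for illustration only):

import org.apache.hadoop.fs.FileStatus
import org.apache.hadoop.mapreduce.Job

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.{OutputWriterFactory, TextBasedFileFormat}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

// Hypothetical line-oriented format. Extending TextBasedFileFormat means
// isSplitable() comes for free: uncompressed files and files using a
// SplittableCompressionCodec (e.g. bzip2) may be split into multiple
// PartitionedFiles, while gzip files are always scanned whole.
class FixedWidthFileFormat extends TextBasedFileFormat {

  override def inferSchema(
      sparkSession: SparkSession,
      options: Map[String, String],
      files: Seq[FileStatus]): Option[StructType] = {
    // Fixed schema for the sketch; a real format would inspect `files`.
    Some(StructType(StructField("value", StringType) :: Nil))
  }

  override def prepareWrite(
      sparkSession: SparkSession,
      job: Job,
      options: Map[String, String],
      dataSchema: StructType): OutputWriterFactory = {
    // Write path is out of scope here; only the read-side split logic matters.
    throw new UnsupportedOperationException("write is not implemented in this sketch")
  }
}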
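
The end-to-end effect can also be sanity-checked from spark-shell, along the lines of the new TextSuite case (the /tmp/spark-15654 path is made up for illustration):

// Write 1000 short rows as gzip-compressed text.
spark.range(0, 1000).selectExpr("CAST(id AS STRING) AS s")
  .write.option("compression", "gzip").mode("overwrite").text("/tmp/spark-15654")

// Make the target split size absurdly small, so any file considered
// splittable would be chopped into many PartitionedFiles.
spark.conf.set("spark.sql.files.maxPartitionBytes", "10")

val df = spark.read.text("/tmp/spark-15654")
// With this patch a .gz file is never split, so all 1000 rows come back
// intact regardless of the split-size setting.
assert(df.count() == 1000)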