diff options
Diffstat (limited to 'sql')
4 files changed, 104 insertions, 3 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 32067011c3..e75e7d2770 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -197,10 +197,15 @@ case class DataSource( SparkHadoopUtil.get.globPathIfNecessary(qualified) }.toArray val fileCatalog = new ListingFileCatalog(sparkSession, globbedPaths, options, None) - format.inferSchema( + val partitionCols = fileCatalog.partitionSpec().partitionColumns.fields + val inferred = format.inferSchema( sparkSession, caseInsensitiveOptions, fileCatalog.allFiles()) + + inferred.map { inferredSchema => + StructType(inferredSchema ++ partitionCols) + } }.getOrElse { throw new AnalysisException("Unable to infer schema. It must be specified manually.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index be023273db..614a6261e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -47,6 +47,13 @@ class FileStreamSource( fs.makeQualified(new Path(path)) // can contains glob patterns } + private val optionsWithPartitionBasePath = sourceOptions.optionMapWithoutPath ++ { + if (!SparkHadoopUtil.get.isGlobPath(new Path(path)) && options.contains("path")) { + Map("basePath" -> path) + } else { + Map() + }} + private val metadataLog = new FileStreamSourceLog(FileStreamSourceLog.VERSION, sparkSession, metadataPath) private var maxBatchId = metadataLog.getLatest().map(_._1).getOrElse(-1L) @@ -136,7 +143,7 @@ class FileStreamSource( paths = files.map(_.path), userSpecifiedSchema = Some(schema), className = fileFormatClassName, - options = sourceOptions.optionMapWithoutPath) + options = optionsWithPartitionBasePath) Dataset.ofRows(sparkSession, LogicalRelation(newDataSource.resolveRelation( checkFilesExist = false))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 55c95ae285..3157afe5a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -102,6 +102,12 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext with Private } } + case class DeleteFile(file: File) extends ExternalAction { + def runAction(): Unit = { + Utils.deleteRecursively(file) + } + } + /** Use `format` and `path` to create FileStreamSource via DataFrameReader */ def createFileStream( format: String, @@ -608,6 +614,81 @@ class FileStreamSourceSuite extends FileStreamSourceTest { // =============== other tests ================ + test("read new files in partitioned table without globbing, should read partition data") { + withTempDirs { case (dir, tmp) => + val partitionFooSubDir = new File(dir, "partition=foo") + val partitionBarSubDir = new File(dir, "partition=bar") + + val schema = new StructType().add("value", StringType).add("partition", StringType) + val fileStream = createFileStream("json", s"${dir.getCanonicalPath}", Some(schema)) + val filtered = fileStream.filter($"value" contains "keep") + testStream(filtered)( + // Create new partition=foo sub dir and write to it + AddTextFileData("{'value': 'drop1'}\n{'value': 'keep2'}", partitionFooSubDir, tmp), + CheckAnswer(("keep2", "foo")), + + // Append to same partition=foo sub dir + AddTextFileData("{'value': 'keep3'}", partitionFooSubDir, tmp), + CheckAnswer(("keep2", "foo"), ("keep3", "foo")), + + // Create new partition sub dir and write to it + AddTextFileData("{'value': 'keep4'}", partitionBarSubDir, tmp), + CheckAnswer(("keep2", "foo"), ("keep3", "foo"), ("keep4", "bar")), + + // Append to same partition=bar sub dir + AddTextFileData("{'value': 'keep5'}", partitionBarSubDir, tmp), + CheckAnswer(("keep2", "foo"), ("keep3", "foo"), ("keep4", "bar"), ("keep5", "bar")) + ) + } + } + + test("when schema inference is turned on, should read partition data") { + def createFile(content: String, src: File, tmp: File): Unit = { + val tempFile = Utils.tempFileWith(new File(tmp, "text")) + val finalFile = new File(src, tempFile.getName) + src.mkdirs() + require(stringToFile(tempFile, content).renameTo(finalFile)) + } + + withSQLConf(SQLConf.STREAMING_SCHEMA_INFERENCE.key -> "true") { + withTempDirs { case (dir, tmp) => + val partitionFooSubDir = new File(dir, "partition=foo") + val partitionBarSubDir = new File(dir, "partition=bar") + + // Create file in partition, so we can infer the schema. + createFile("{'value': 'drop0'}", partitionFooSubDir, tmp) + + val fileStream = createFileStream("json", s"${dir.getCanonicalPath}") + val filtered = fileStream.filter($"value" contains "keep") + testStream(filtered)( + // Append to same partition=foo sub dir + AddTextFileData("{'value': 'drop1'}\n{'value': 'keep2'}", partitionFooSubDir, tmp), + CheckAnswer(("keep2", "foo")), + + // Append to same partition=foo sub dir + AddTextFileData("{'value': 'keep3'}", partitionFooSubDir, tmp), + CheckAnswer(("keep2", "foo"), ("keep3", "foo")), + + // Create new partition sub dir and write to it + AddTextFileData("{'value': 'keep4'}", partitionBarSubDir, tmp), + CheckAnswer(("keep2", "foo"), ("keep3", "foo"), ("keep4", "bar")), + + // Append to same partition=bar sub dir + AddTextFileData("{'value': 'keep5'}", partitionBarSubDir, tmp), + CheckAnswer(("keep2", "foo"), ("keep3", "foo"), ("keep4", "bar"), ("keep5", "bar")), + + // Delete the two partition dirs + DeleteFile(partitionFooSubDir), + DeleteFile(partitionBarSubDir), + + AddTextFileData("{'value': 'keep6'}", partitionBarSubDir, tmp), + CheckAnswer(("keep2", "foo"), ("keep3", "foo"), ("keep4", "bar"), ("keep5", "bar"), + ("keep6", "bar")) + ) + } + } + } + test("fault tolerance") { withTempDirs { case (src, tmp) => val fileStream = createFileStream("text", src.getCanonicalPath) @@ -792,7 +873,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } assert(src.listFiles().size === numFiles) - val files = spark.readStream.text(root.getCanonicalPath).as[String] + val files = spark.readStream.text(root.getCanonicalPath).as[(String, Int)] // Note this query will use constant folding to eliminate the file scan. // This is to avoid actually running a Spark job with 10000 tasks diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 6c5b170d9c..aa6515bc7a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -95,6 +95,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { def addData(query: Option[StreamExecution]): (Source, Offset) } + /** A trait that can be extended when testing a source. */ + trait ExternalAction extends StreamAction { + def runAction(): Unit + } + case class AddDataMemory[A](source: MemoryStream[A], data: Seq[A]) extends AddData { override def toString: String = s"AddData to $source: ${data.mkString(",")}" @@ -429,6 +434,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { failTest("Error adding data", e) } + case e: ExternalAction => + e.runAction() + case CheckAnswerRows(expectedAnswer, lastOnly, isSorted) => verify(currentStream != null, "stream not running") // Get the map of source index to the current source objects |