4 files changed, 116 insertions, 49 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 0dfe7dba1e..07bc8ae148 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -67,7 +67,10 @@ case class DataSource(
     bucketSpec: Option[BucketSpec] = None,
     options: Map[String, String] = Map.empty) extends Logging {
 
+  case class SourceInfo(name: String, schema: StructType)
+
   lazy val providingClass: Class[_] = lookupDataSource(className)
+  lazy val sourceInfo = sourceSchema()
 
   /** A map to maintain backward compatibility in case we move data sources around. */
   private val backwardCompatibilityMap = Map(
@@ -145,17 +148,19 @@ case class DataSource(
   }
 
   /** Returns the name and schema of the source that can be used to continually read data. */
-  def sourceSchema(): (String, StructType) = {
+  private def sourceSchema(): SourceInfo = {
     providingClass.newInstance() match {
       case s: StreamSourceProvider =>
-        s.sourceSchema(sqlContext, userSpecifiedSchema, className, options)
+        val (name, schema) = s.sourceSchema(sqlContext, userSpecifiedSchema, className, options)
+        SourceInfo(name, schema)
 
       case format: FileFormat =>
         val caseInsensitiveOptions = new CaseInsensitiveMap(options)
         val path = caseInsensitiveOptions.getOrElse("path", {
          throw new IllegalArgumentException("'path' is not specified")
        })
-        (s"FileSource[$path]", inferFileFormatSchema(format))
+        SourceInfo(s"FileSource[$path]", inferFileFormatSchema(format))
+
       case _ =>
         throw new UnsupportedOperationException(
           s"Data source $className does not support streamed reading")
@@ -174,24 +179,20 @@ case class DataSource(
           throw new IllegalArgumentException("'path' is not specified")
         })
 
-        val dataSchema = inferFileFormatSchema(format)
-
         def dataFrameBuilder(files: Array[String]): DataFrame = {
-          Dataset.ofRows(
-            sqlContext,
-            LogicalRelation(
-              DataSource(
-                sqlContext,
-                paths = files,
-                userSpecifiedSchema = Some(dataSchema),
-                className = className,
-                options =
-                  new CaseInsensitiveMap(
-                    options.filterKeys(_ != "path") + ("basePath" -> path))).resolveRelation()))
+          val newOptions = options.filterKeys(_ != "path") + ("basePath" -> path)
+          val newDataSource =
+            DataSource(
+              sqlContext,
+              paths = files,
+              userSpecifiedSchema = Some(sourceInfo.schema),
+              className = className,
+              options = new CaseInsensitiveMap(newOptions))
+          Dataset.ofRows(sqlContext, LogicalRelation(newDataSource.resolveRelation()))
         }
 
         new FileStreamSource(
-          sqlContext, metadataPath, path, Some(dataSchema), className, dataFrameBuilder)
+          sqlContext, metadataPath, path, sourceInfo.schema, dataFrameBuilder)
       case _ =>
         throw new UnsupportedOperationException(
           s"Data source $className does not support streamed reading")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
index 6448cb6e90..51c3aee835 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
@@ -35,8 +35,7 @@ class FileStreamSource(
     sqlContext: SQLContext,
     metadataPath: String,
     path: String,
-    dataSchema: Option[StructType],
-    providerName: String,
+    override val schema: StructType,
     dataFrameBuilder: Array[String] => DataFrame) extends Source with Logging {
 
   private val fs = new Path(path).getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
@@ -48,24 +47,6 @@ class FileStreamSource(
     files.foreach(seenFiles.add)
   }
 
-  /** Returns the schema of the data from this source */
-  override lazy val schema: StructType = {
-    dataSchema.getOrElse {
-      val filesPresent = fetchAllFiles()
-      if (filesPresent.isEmpty) {
-        if (providerName == "text") {
-          // Add a default schema for "text"
-          new StructType().add("value", StringType)
-        } else {
-          throw new IllegalArgumentException("No schema specified")
-        }
-      } else {
-        // There are some existing files. Use them to infer the schema.
-        dataFrameBuilder(filesPresent.toArray).schema
-      }
-    }
-  }
-
   /**
    * Returns the maximum offset that can be retrieved from the source.
    *
@@ -118,7 +99,6 @@ class FileStreamSource(
     logInfo(s"Processing ${files.length} files from ${startId + 1}:$endId")
     logDebug(s"Streaming ${files.mkString(", ")}")
     dataFrameBuilder(files)
-
   }
 
   private def fetchAllFiles(): Seq[String] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
index c29291eb58..3341580fc4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
@@ -23,8 +23,8 @@ import org.apache.spark.sql.execution.datasources.DataSource
 
 object StreamingRelation {
   def apply(dataSource: DataSource): StreamingRelation = {
-    val (name, schema) = dataSource.sourceSchema()
-    StreamingRelation(dataSource, name, schema.toAttributes)
+    StreamingRelation(
+      dataSource, dataSource.sourceInfo.name, dataSource.sourceInfo.schema.toAttributes)
   }
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
index 64cddf0dee..45dca2fadf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
@@ -19,11 +19,11 @@ package org.apache.spark.sql.streaming
 
 import java.io.File
 
-import org.apache.spark.sql.{AnalysisException, StreamTest}
+import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.sql.types.{StringType, StructField, StructType}
 import org.apache.spark.util.Utils
 
 class FileStreamSourceTest extends StreamTest with SharedSQLContext {
@@ -44,20 +44,32 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext {
 
   case class AddParquetFileData(
       source: FileStreamSource,
-      content: Seq[String],
+      df: DataFrame,
       src: File,
       tmp: File) extends AddData {
 
     override def addData(): Offset = {
       source.withBatchingLocked {
-        val file = Utils.tempFileWith(new File(tmp, "parquet"))
-        content.toDS().toDF().write.parquet(file.getCanonicalPath)
-        file.renameTo(new File(src, file.getName))
+        AddParquetFileData.writeToFile(df, src, tmp)
         source.currentOffset
       } + 1
     }
   }
 
+  object AddParquetFileData {
+    def apply(
+        source: FileStreamSource,
+        seq: Seq[String],
+        src: File,
+        tmp: File): AddParquetFileData = new AddParquetFileData(source, seq.toDS().toDF(), src, tmp)
+
+    def writeToFile(df: DataFrame, src: File, tmp: File): Unit = {
+      val file = Utils.tempFileWith(new File(tmp, "parquet"))
+      df.write.parquet(file.getCanonicalPath)
+      file.renameTo(new File(src, file.getName))
+    }
+  }
+
   /** Use `format` and `path` to create FileStreamSource via DataFrameReader */
   def createFileStreamSource(
       format: String,
@@ -78,6 +90,17 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext {
     }.head
   }
 
+  def withTempDirs(body: (File, File) => Unit) {
+    val src = Utils.createTempDir(namePrefix = "streaming.src")
+    val tmp = Utils.createTempDir(namePrefix = "streaming.tmp")
+    try {
+      body(src, tmp)
+    } finally {
+      Utils.deleteRecursively(src)
+      Utils.deleteRecursively(tmp)
+    }
+  }
+
   val valueSchema = new StructType().add("value", StringType)
 }
 
@@ -99,9 +122,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext {
       reader.stream()
     }
     df.queryExecution.analyzed
-      .collect { case StreamingRelation(dataSource, _, _) =>
-        dataSource.sourceSchema()
-      }.head._2
+      .collect { case s @ StreamingRelation(dataSource, _, _) => s.schema }.head
   }
 
   test("FileStreamSource schema: no path") {
@@ -305,6 +326,39 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext {
     )
   }
 
+
+  test("reading from json files with changing schema") {
+    withTempDirs { case (src, tmp) =>
+
+      // Add a file so that we can infer its schema
+      stringToFile(new File(src, "existing"), "{'k': 'value0'}")
+
+      val textSource = createFileStreamSource("json", src.getCanonicalPath)
+
+      // FileStreamSource should infer the column "k"
+      val text = textSource.toDF()
+      assert(text.schema === StructType(Seq(StructField("k", StringType))))
+
+      // After creating DF and before starting stream, add data with different schema
+      // Should not affect the inferred schema any more
+      stringToFile(new File(src, "existing2"), "{'k': 'value1', 'v': 'new'}")
+
+      testStream(text)(
+
+        // Should not pick up column v in the file added before start
+        AddTextFileData(textSource, "{'k': 'value2'}", src, tmp),
+        CheckAnswer("value0", "value1", "value2"),
+
+        // Should read data in column k, and ignore v
+        AddTextFileData(textSource, "{'k': 'value3', 'v': 'new'}", src, tmp),
+        CheckAnswer("value0", "value1", "value2", "value3"),
+
+        // Should ignore rows that do not have the necessary k column
+        AddTextFileData(textSource, "{'v': 'value4'}", src, tmp),
+        CheckAnswer("value0", "value1", "value2", "value3", null))
+    }
+  }
+
   test("read from parquet files") {
     val src = Utils.createTempDir(namePrefix = "streaming.src")
     val tmp = Utils.createTempDir(namePrefix = "streaming.tmp")
@@ -327,6 +381,38 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext {
     Utils.deleteRecursively(tmp)
   }
 
+  test("read from parquet files with changing schema") {
+
+    withTempDirs { case (src, tmp) =>
+      // Add a file so that we can infer its schema
+      AddParquetFileData.writeToFile(Seq("value0").toDF("k"), src, tmp)
+
+      val fileSource = createFileStreamSource("parquet", src.getCanonicalPath)
+      val parquetData = fileSource.toDF()
+
+      // FileStreamSource should infer the column "k"
+      assert(parquetData.schema === StructType(Seq(StructField("k", StringType))))
+
+      // After creating DF and before starting stream, add data with different schema
+      // Should not affect the inferred schema any more
+      AddParquetFileData.writeToFile(Seq(("value1", 0)).toDF("k", "v"), src, tmp)
+
+      testStream(parquetData)(
+        // Should not pick up column v in the file added before start
+        AddParquetFileData(fileSource, Seq("value2").toDF("k"), src, tmp),
+        CheckAnswer("value0", "value1", "value2"),
+
+        // Should read data in column k, and ignore v
+        AddParquetFileData(fileSource, Seq(("value3", 1)).toDF("k", "v"), src, tmp),
+        CheckAnswer("value0", "value1", "value2", "value3"),
+
+        // Should ignore rows that do not have the necessary k column
+        AddParquetFileData(fileSource, Seq("value5").toDF("v"), src, tmp),
+        CheckAnswer("value0", "value1", "value2", "value3", null)
+      )
+    }
+  }
+
   test("file stream source without schema") {
     val src = Utils.createTempDir(namePrefix = "streaming.src")