Diffstat (limited to 'sql')
 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala |  5 ++---
 sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala  | 32 ++++++++++++++++++++++++++++++++
 2 files changed, 34 insertions(+), 3 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 6dc27c1952..f572b93991 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -485,12 +485,11 @@ case class DataSource(
             data.logicalPlan,
             mode)
         sparkSession.sessionState.executePlan(plan).toRdd
+        // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring it.
+        copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation()
       case _ =>
         sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.")
     }
-
-    // We replace the schema with that of the DataFrame we just wrote out to avoid re-inferring it.
-    copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation()
   }
 }
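
Why the relocation matters: `resolveRelation()` dispatches on both the provider type and whether a user-specified schema is present. The old code re-resolved with `userSpecifiedSchema` set after every write, which breaks sources that implement `RelationProvider` but not `SchemaRelationProvider`, since such sources cannot accept a schema. A simplified sketch of that dispatch, abbreviated from the surrounding `DataSource.scala` rather than quoted verbatim:

```scala
// Simplified sketch of the dispatch inside DataSource.resolveRelation().
// The FileFormat and remaining error branches are elided for brevity.
(providingClass.newInstance(), userSpecifiedSchema) match {
  case (dataSource: SchemaRelationProvider, Some(schema)) =>
    // Sources that opt in to user-specified schemas receive the schema directly.
    dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions, schema)
  case (dataSource: RelationProvider, None) =>
    // Plain RelationProviders declare or infer their own schema.
    dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions)
  case (_: RelationProvider, Some(_)) =>
    // A RelationProvider-only source cannot honor a user-specified schema, so the
    // old unconditional copy(userSpecifiedSchema = ...) after write() failed here.
    throw new AnalysisException(s"$className does not allow user-specified schemas.")
}
```

By moving `copy(...).resolveRelation()` into the `FileFormat` branch, only file-based writes re-resolve with the just-written schema (to skip re-inference), while `CreatableRelationProvider` sources return the relation produced by `createRelation` directly.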
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
index d454100ccb..05935cec4b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
@@ -82,6 +82,29 @@ class DefaultSource
   }
 }
 
+/** Dummy provider with only RelationProvider and CreatableRelationProvider. */
+class DefaultSourceWithoutUserSpecifiedSchema
+  extends RelationProvider
+  with CreatableRelationProvider {
+
+  case class FakeRelation(sqlContext: SQLContext) extends BaseRelation {
+    override def schema: StructType = StructType(Seq(StructField("a", StringType)))
+  }
+
+  override def createRelation(
+      sqlContext: SQLContext,
+      parameters: Map[String, String]): BaseRelation = {
+    FakeRelation(sqlContext)
+  }
+
+  override def createRelation(
+      sqlContext: SQLContext,
+      mode: SaveMode,
+      parameters: Map[String, String],
+      data: DataFrame): BaseRelation = {
+    FakeRelation(sqlContext)
+  }
+}
+
 class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with BeforeAndAfter {
@@ -120,6 +143,15 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
       .save()
   }
 
+  test("resolve default source without extending SchemaRelationProvider") {
+    spark.read
+      .format("org.apache.spark.sql.test.DefaultSourceWithoutUserSpecifiedSchema")
+      .load()
+      .write
+      .format("org.apache.spark.sql.test.DefaultSourceWithoutUserSpecifiedSchema")
+      .save()
+  }
+
   test("resolve full class") {
     spark.read
       .format("org.apache.spark.sql.test.DefaultSource")