aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorTathagata Das <tathagata.das1565@gmail.com>2016-06-20 14:52:28 -0700
committerShixiong Zhu <shixiong@databricks.com>2016-06-20 14:52:28 -0700
commitb99129cc452defc266f6d357f5baab5f4ff37a36 (patch)
treede4e6e356930aeacee94b541530be063d178707c /sql
parent6df8e3886063a9d8c2e8499456ea9166245d5640 (diff)
downloadspark-b99129cc452defc266f6d357f5baab5f4ff37a36.tar.gz
spark-b99129cc452defc266f6d357f5baab5f4ff37a36.tar.bz2
spark-b99129cc452defc266f6d357f5baab5f4ff37a36.zip
[SPARK-15982][SPARK-16009][SPARK-16007][SQL] Harmonize the behavior of DataFrameReader.text/csv/json/parquet/orc
## What changes were proposed in this pull request? Issues with current reader behavior. - `text()` without args returns an empty DF with no columns -> inconsistent, its expected that text will always return a DF with `value` string field, - `textFile()` without args fails with exception because of the above reason, it expected the DF returned by `text()` to have a `value` field. - `orc()` does not have var args, inconsistent with others - `json(single-arg)` was removed, but that caused source compatibility issues - [SPARK-16009](https://issues.apache.org/jira/browse/SPARK-16009) - user specified schema was not respected when `text/csv/...` were used with no args - [SPARK-16007](https://issues.apache.org/jira/browse/SPARK-16007) The solution I am implementing is to do the following. - For each format, there will be a single argument method, and a vararg method. For json, parquet, csv, text, this means adding json(string), etc.. For orc, this means adding orc(varargs). - Remove the special handling of text(), csv(), etc. that returns empty dataframe with no fields. Rather pass on the empty sequence of paths to the datasource, and let each datasource handle it right. For e.g, text data source, should return empty DF with schema (value: string) - Deduped docs and fixed their formatting. ## How was this patch tested? Added new unit tests for Scala and Java tests Author: Tathagata Das <tathagata.das1565@gmail.com> Closes #13727 from tdas/SPARK-15982.
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala132
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameReaderWriterSuite.java158
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala186
3 files changed, 420 insertions, 56 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 2ae854d04f..841503b260 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -119,13 +119,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* @since 1.4.0
*/
def load(): DataFrame = {
- val dataSource =
- DataSource(
- sparkSession,
- userSpecifiedSchema = userSpecifiedSchema,
- className = source,
- options = extraOptions.toMap)
- Dataset.ofRows(sparkSession, LogicalRelation(dataSource.resolveRelation()))
+ load(Seq.empty: _*) // force invocation of `load(...varargs...)`
}
/**
@@ -135,7 +129,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* @since 1.4.0
*/
def load(path: String): DataFrame = {
- option("path", path).load()
+ load(Seq(path): _*) // force invocation of `load(...varargs...)`
}
/**
@@ -146,18 +140,15 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
*/
@scala.annotation.varargs
def load(paths: String*): DataFrame = {
- if (paths.isEmpty) {
- sparkSession.emptyDataFrame
- } else {
- sparkSession.baseRelationToDataFrame(
- DataSource.apply(
- sparkSession,
- paths = paths,
- userSpecifiedSchema = userSpecifiedSchema,
- className = source,
- options = extraOptions.toMap).resolveRelation())
- }
+ sparkSession.baseRelationToDataFrame(
+ DataSource.apply(
+ sparkSession,
+ paths = paths,
+ userSpecifiedSchema = userSpecifiedSchema,
+ className = source,
+ options = extraOptions.toMap).resolveRelation())
}
+
/**
* Construct a [[DataFrame]] representing the database table accessible via JDBC URL
* url named table and connection properties.
@@ -247,11 +238,23 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
/**
* Loads a JSON file (one object per line) and returns the result as a [[DataFrame]].
+ * See the documentation on the overloaded `json()` method with varargs for more details.
+ *
+ * @since 1.4.0
+ */
+ def json(path: String): DataFrame = {
+ // This method ensures that calls that explicit need single argument works, see SPARK-16009
+ json(Seq(path): _*)
+ }
+
+ /**
+ * Loads a JSON file (one object per line) and returns the result as a [[DataFrame]].
*
* This function goes through the input once to determine the input schema. If you know the
* schema in advance, use the version that specifies the schema to avoid the extra scan.
*
* You can set the following JSON-specific options to deal with non-standard JSON files:
+ * <ul>
* <li>`primitivesAsString` (default `false`): infers all primitive values as a string type</li>
* <li>`prefersDecimal` (default `false`): infers all floating-point values as a decimal
* type. If the values do not fit in decimal, then it infers them as doubles.</li>
@@ -266,17 +269,17 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* <li>`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
* during parsing.</li>
* <ul>
- * <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts the
- * malformed string into a new field configured by `columnNameOfCorruptRecord`. When
+ * <li> - `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts
+ * the malformed string into a new field configured by `columnNameOfCorruptRecord`. When
* a schema is set by user, it sets `null` for extra fields.</li>
- * <li>`DROPMALFORMED` : ignores the whole corrupted records.</li>
- * <li>`FAILFAST` : throws an exception when it meets corrupted records.</li>
+ * <li> - `DROPMALFORMED` : ignores the whole corrupted records.</li>
+ * <li> - `FAILFAST` : throws an exception when it meets corrupted records.</li>
* </ul>
* <li>`columnNameOfCorruptRecord` (default is the value specified in
* `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string
* created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li>
- *
- * @since 1.6.0
+ * </ul>
+ * @since 2.0.0
*/
@scala.annotation.varargs
def json(paths: String*): DataFrame = format("json").load(paths : _*)
@@ -327,6 +330,17 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
}
/**
+ * Loads a CSV file and returns the result as a [[DataFrame]]. See the documentation on the
+ * other overloaded `csv()` method for more details.
+ *
+ * @since 2.0.0
+ */
+ def csv(path: String): DataFrame = {
+ // This method ensures that calls that explicit need single argument works, see SPARK-16009
+ csv(Seq(path): _*)
+ }
+
+ /**
* Loads a CSV file and returns the result as a [[DataFrame]].
*
* This function will go through the input once to determine the input schema if `inferSchema`
@@ -334,6 +348,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* specify the schema explicitly using [[schema]].
*
* You can set the following CSV-specific options to deal with CSV files:
+ * <ul>
* <li>`sep` (default `,`): sets the single character as a separator for each
* field and value.</li>
* <li>`encoding` (default `UTF-8`): decodes the CSV files by the given encoding
@@ -370,26 +385,37 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* <li>`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
* during parsing.</li>
* <ul>
- * <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record. When
+ * <li> - `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record. When
* a schema is set by user, it sets `null` for extra fields.</li>
- * <li>`DROPMALFORMED` : ignores the whole corrupted records.</li>
- * <li>`FAILFAST` : throws an exception when it meets corrupted records.</li>
+ * <li> - `DROPMALFORMED` : ignores the whole corrupted records.</li>
+ * <li> - `FAILFAST` : throws an exception when it meets corrupted records.</li>
+ * </ul>
* </ul>
- *
* @since 2.0.0
*/
@scala.annotation.varargs
def csv(paths: String*): DataFrame = format("csv").load(paths : _*)
/**
- * Loads a Parquet file, returning the result as a [[DataFrame]]. This function returns an empty
- * [[DataFrame]] if no paths are passed in.
+ * Loads a Parquet file, returning the result as a [[DataFrame]]. See the documentation
+ * on the other overloaded `parquet()` method for more details.
+ *
+ * @since 2.0.0
+ */
+ def parquet(path: String): DataFrame = {
+ // This method ensures that calls that explicit need single argument works, see SPARK-16009
+ parquet(Seq(path): _*)
+ }
+
+ /**
+ * Loads a Parquet file, returning the result as a [[DataFrame]].
*
* You can set the following Parquet-specific option(s) for reading Parquet files:
+ * <ul>
* <li>`mergeSchema` (default is the value specified in `spark.sql.parquet.mergeSchema`): sets
* whether we should merge schemas collected from all Parquet part-files. This will override
* `spark.sql.parquet.mergeSchema`.</li>
- *
+ * </ul>
* @since 1.4.0
*/
@scala.annotation.varargs
@@ -404,7 +430,20 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* @since 1.5.0
* @note Currently, this method can only be used after enabling Hive support.
*/
- def orc(path: String): DataFrame = format("orc").load(path)
+ def orc(path: String): DataFrame = {
+ // This method ensures that calls that explicit need single argument works, see SPARK-16009
+ orc(Seq(path): _*)
+ }
+
+ /**
+ * Loads an ORC file and returns the result as a [[DataFrame]].
+ *
+ * @param paths input paths
+ * @since 2.0.0
+ * @note Currently, this method can only be used after enabling Hive support.
+ */
+ @scala.annotation.varargs
+ def orc(paths: String*): DataFrame = format("orc").load(paths: _*)
/**
* Returns the specified table as a [[DataFrame]].
@@ -419,6 +458,18 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
/**
* Loads text files and returns a [[DataFrame]] whose schema starts with a string column named
+ * "value", and followed by partitioned columns if there are any. See the documentation on
+ * the other overloaded `text()` method for more details.
+ *
+ * @since 2.0.0
+ */
+ def text(path: String): DataFrame = {
+ // This method ensures that calls that explicit need single argument works, see SPARK-16009
+ text(Seq(path): _*)
+ }
+
+ /**
+ * Loads text files and returns a [[DataFrame]] whose schema starts with a string column named
* "value", and followed by partitioned columns if there are any.
*
* Each line in the text files is a new row in the resulting DataFrame. For example:
@@ -430,13 +481,23 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* spark.read().text("/path/to/spark/README.md")
* }}}
*
- * @param paths input path
+ * @param paths input paths
* @since 1.6.0
*/
@scala.annotation.varargs
def text(paths: String*): DataFrame = format("text").load(paths : _*)
/**
+ * Loads text files and returns a [[Dataset]] of String. See the documentation on the
+ * other overloaded `textFile()` method for more details.
+ * @since 2.0.0
+ */
+ def textFile(path: String): Dataset[String] = {
+ // This method ensures that calls that explicit need single argument works, see SPARK-16009
+ textFile(Seq(path): _*)
+ }
+
+ /**
* Loads text files and returns a [[Dataset]] of String. The underlying schema of the Dataset
* contains a single string column named "value".
*
@@ -457,6 +518,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
*/
@scala.annotation.varargs
def textFile(paths: String*): Dataset[String] = {
+ if (userSpecifiedSchema.nonEmpty) {
+ throw new AnalysisException("User specified schema not supported with `textFile`")
+ }
text(paths : _*).select("value").as[String](sparkSession.implicits.newStringEncoder)
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameReaderWriterSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameReaderWriterSuite.java
new file mode 100644
index 0000000000..7babf7573c
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameReaderWriterSuite.java
@@ -0,0 +1,158 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package test.org.apache.spark.sql;
+
+import java.io.File;
+import java.util.HashMap;
+
+import org.apache.spark.sql.SaveMode;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.test.TestSparkSession;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.util.Utils;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+public class JavaDataFrameReaderWriterSuite {
+ private SparkSession spark = new TestSparkSession();
+ private StructType schema = new StructType().add("s", "string");
+ private transient String input;
+ private transient String output;
+
+ @Before
+ public void setUp() {
+ input = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "input").toString();
+ File f = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "output");
+ f.delete();
+ output = f.toString();
+ }
+
+ @After
+ public void tearDown() {
+ spark.stop();
+ spark = null;
+ }
+
+ @Test
+ public void testFormatAPI() {
+ spark
+ .read()
+ .format("org.apache.spark.sql.test")
+ .load()
+ .write()
+ .format("org.apache.spark.sql.test")
+ .save();
+ }
+
+ @Test
+ public void testOptionsAPI() {
+ HashMap<String, String> map = new HashMap<String, String>();
+ map.put("e", "1");
+ spark
+ .read()
+ .option("a", "1")
+ .option("b", 1)
+ .option("c", 1.0)
+ .option("d", true)
+ .options(map)
+ .text()
+ .write()
+ .option("a", "1")
+ .option("b", 1)
+ .option("c", 1.0)
+ .option("d", true)
+ .options(map)
+ .format("org.apache.spark.sql.test")
+ .save();
+ }
+
+ @Test
+ public void testSaveModeAPI() {
+ spark
+ .range(10)
+ .write()
+ .format("org.apache.spark.sql.test")
+ .mode(SaveMode.ErrorIfExists)
+ .save();
+ }
+
+ @Test
+ public void testLoadAPI() {
+ spark.read().format("org.apache.spark.sql.test").load();
+ spark.read().format("org.apache.spark.sql.test").load(input);
+ spark.read().format("org.apache.spark.sql.test").load(input, input, input);
+ spark.read().format("org.apache.spark.sql.test").load(new String[]{input, input});
+ }
+
+ @Test
+ public void testTextAPI() {
+ spark.read().text();
+ spark.read().text(input);
+ spark.read().text(input, input, input);
+ spark.read().text(new String[]{input, input})
+ .write().text(output);
+ }
+
+ @Test
+ public void testTextFileAPI() {
+ spark.read().textFile();
+ spark.read().textFile(input);
+ spark.read().textFile(input, input, input);
+ spark.read().textFile(new String[]{input, input});
+ }
+
+ @Test
+ public void testCsvAPI() {
+ spark.read().schema(schema).csv();
+ spark.read().schema(schema).csv(input);
+ spark.read().schema(schema).csv(input, input, input);
+ spark.read().schema(schema).csv(new String[]{input, input})
+ .write().csv(output);
+ }
+
+ @Test
+ public void testJsonAPI() {
+ spark.read().schema(schema).json();
+ spark.read().schema(schema).json(input);
+ spark.read().schema(schema).json(input, input, input);
+ spark.read().schema(schema).json(new String[]{input, input})
+ .write().json(output);
+ }
+
+ @Test
+ public void testParquetAPI() {
+ spark.read().schema(schema).parquet();
+ spark.read().schema(schema).parquet(input);
+ spark.read().schema(schema).parquet(input, input, input);
+ spark.read().schema(schema).parquet(new String[] { input, input })
+ .write().parquet(output);
+ }
+
+ /**
+ * This only tests whether API compiles, but does not run it as orc()
+ * cannot be run without Hive classes.
+ */
+ public void testOrcAPI() {
+ spark.read().schema(schema).orc();
+ spark.read().schema(schema).orc(input);
+ spark.read().schema(schema).orc(input, input, input);
+ spark.read().schema(schema).orc(new String[]{input, input})
+ .write().orc(output);
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
index 98e57b3804..3fa3864bc9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
@@ -17,6 +17,10 @@
package org.apache.spark.sql.test
+import java.io.File
+
+import org.scalatest.BeforeAndAfter
+
import org.apache.spark.sql._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{StringType, StructField, StructType}
@@ -79,10 +83,19 @@ class DefaultSource
}
-class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext {
+class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with BeforeAndAfter {
+
- private def newMetadataDir =
- Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+ private val userSchema = new StructType().add("s", StringType)
+ private val textSchema = new StructType().add("value", StringType)
+ private val data = Seq("1", "2", "3")
+ private val dir = Utils.createTempDir(namePrefix = "input").getCanonicalPath
+ private implicit var enc: Encoder[String] = _
+
+ before {
+ enc = spark.implicits.newStringEncoder
+ Utils.deleteRecursively(new File(dir))
+ }
test("writeStream cannot be called on non-streaming datasets") {
val e = intercept[AnalysisException] {
@@ -157,24 +170,6 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext {
assert(LastOptions.saveMode === SaveMode.ErrorIfExists)
}
- test("paths") {
- val df = spark.read
- .format("org.apache.spark.sql.test")
- .option("checkpointLocation", newMetadataDir)
- .load("/test")
-
- assert(LastOptions.parameters("path") == "/test")
-
- LastOptions.clear()
-
- df.write
- .format("org.apache.spark.sql.test")
- .option("checkpointLocation", newMetadataDir)
- .save("/test")
-
- assert(LastOptions.parameters("path") == "/test")
- }
-
test("test different data types for options") {
val df = spark.read
.format("org.apache.spark.sql.test")
@@ -193,7 +188,6 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext {
.option("intOpt", 56)
.option("boolOpt", false)
.option("doubleOpt", 6.7)
- .option("checkpointLocation", newMetadataDir)
.save("/test")
assert(LastOptions.parameters("intOpt") == "56")
@@ -228,4 +222,152 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext {
}
}
}
+
+ test("load API") {
+ spark.read.format("org.apache.spark.sql.test").load()
+ spark.read.format("org.apache.spark.sql.test").load(dir)
+ spark.read.format("org.apache.spark.sql.test").load(dir, dir, dir)
+ spark.read.format("org.apache.spark.sql.test").load(Seq(dir, dir): _*)
+ Option(dir).map(spark.read.format("org.apache.spark.sql.test").load)
+ }
+
+ test("text - API and behavior regarding schema") {
+ // Writer
+ spark.createDataset(data).write.mode(SaveMode.Overwrite).text(dir)
+ testRead(spark.read.text(dir), data, textSchema)
+
+ // Reader, without user specified schema
+ testRead(spark.read.text(), Seq.empty, textSchema)
+ testRead(spark.read.text(dir, dir, dir), data ++ data ++ data, textSchema)
+ testRead(spark.read.text(Seq(dir, dir): _*), data ++ data, textSchema)
+ // Test explicit calls to single arg method - SPARK-16009
+ testRead(Option(dir).map(spark.read.text).get, data, textSchema)
+
+ // Reader, with user specified schema, should just apply user schema on the file data
+ testRead(spark.read.schema(userSchema).text(), Seq.empty, userSchema)
+ testRead(spark.read.schema(userSchema).text(dir), data, userSchema)
+ testRead(spark.read.schema(userSchema).text(dir, dir), data ++ data, userSchema)
+ testRead(spark.read.schema(userSchema).text(Seq(dir, dir): _*), data ++ data, userSchema)
+ }
+
+ test("textFile - API and behavior regarding schema") {
+ spark.createDataset(data).write.mode(SaveMode.Overwrite).text(dir)
+
+ // Reader, without user specified schema
+ testRead(spark.read.textFile().toDF(), Seq.empty, textSchema)
+ testRead(spark.read.textFile(dir).toDF(), data, textSchema)
+ testRead(spark.read.textFile(dir, dir).toDF(), data ++ data, textSchema)
+ testRead(spark.read.textFile(Seq(dir, dir): _*).toDF(), data ++ data, textSchema)
+ // Test explicit calls to single arg method - SPARK-16009
+ testRead(Option(dir).map(spark.read.text).get, data, textSchema)
+
+ // Reader, with user specified schema, should just apply user schema on the file data
+ val e = intercept[AnalysisException] { spark.read.schema(userSchema).textFile() }
+ assert(e.getMessage.toLowerCase.contains("user specified schema not supported"))
+ intercept[AnalysisException] { spark.read.schema(userSchema).textFile(dir) }
+ intercept[AnalysisException] { spark.read.schema(userSchema).textFile(dir, dir) }
+ intercept[AnalysisException] { spark.read.schema(userSchema).textFile(Seq(dir, dir): _*) }
+ }
+
+ test("csv - API and behavior regarding schema") {
+ // Writer
+ spark.createDataset(data).toDF("str").write.mode(SaveMode.Overwrite).csv(dir)
+ val df = spark.read.csv(dir)
+ checkAnswer(df, spark.createDataset(data).toDF())
+ val schema = df.schema
+
+ // Reader, without user specified schema
+ intercept[IllegalArgumentException] {
+ testRead(spark.read.csv(), Seq.empty, schema)
+ }
+ testRead(spark.read.csv(dir), data, schema)
+ testRead(spark.read.csv(dir, dir), data ++ data, schema)
+ testRead(spark.read.csv(Seq(dir, dir): _*), data ++ data, schema)
+ // Test explicit calls to single arg method - SPARK-16009
+ testRead(Option(dir).map(spark.read.csv).get, data, schema)
+
+ // Reader, with user specified schema, should just apply user schema on the file data
+ testRead(spark.read.schema(userSchema).csv(), Seq.empty, userSchema)
+ testRead(spark.read.schema(userSchema).csv(dir), data, userSchema)
+ testRead(spark.read.schema(userSchema).csv(dir, dir), data ++ data, userSchema)
+ testRead(spark.read.schema(userSchema).csv(Seq(dir, dir): _*), data ++ data, userSchema)
+ }
+
+ test("json - API and behavior regarding schema") {
+ // Writer
+ spark.createDataset(data).toDF("str").write.mode(SaveMode.Overwrite).json(dir)
+ val df = spark.read.json(dir)
+ checkAnswer(df, spark.createDataset(data).toDF())
+ val schema = df.schema
+
+ // Reader, without user specified schema
+ intercept[AnalysisException] {
+ testRead(spark.read.json(), Seq.empty, schema)
+ }
+ testRead(spark.read.json(dir), data, schema)
+ testRead(spark.read.json(dir, dir), data ++ data, schema)
+ testRead(spark.read.json(Seq(dir, dir): _*), data ++ data, schema)
+ // Test explicit calls to single arg method - SPARK-16009
+ testRead(Option(dir).map(spark.read.json).get, data, schema)
+
+ // Reader, with user specified schema, data should be nulls as schema in file different
+ // from user schema
+ val expData = Seq[String](null, null, null)
+ testRead(spark.read.schema(userSchema).json(), Seq.empty, userSchema)
+ testRead(spark.read.schema(userSchema).json(dir), expData, userSchema)
+ testRead(spark.read.schema(userSchema).json(dir, dir), expData ++ expData, userSchema)
+ testRead(spark.read.schema(userSchema).json(Seq(dir, dir): _*), expData ++ expData, userSchema)
+ }
+
+ test("parquet - API and behavior regarding schema") {
+ // Writer
+ spark.createDataset(data).toDF("str").write.mode(SaveMode.Overwrite).parquet(dir)
+ val df = spark.read.parquet(dir)
+ checkAnswer(df, spark.createDataset(data).toDF())
+ val schema = df.schema
+
+ // Reader, without user specified schema
+ intercept[AnalysisException] {
+ testRead(spark.read.parquet(), Seq.empty, schema)
+ }
+ testRead(spark.read.parquet(dir), data, schema)
+ testRead(spark.read.parquet(dir, dir), data ++ data, schema)
+ testRead(spark.read.parquet(Seq(dir, dir): _*), data ++ data, schema)
+ // Test explicit calls to single arg method - SPARK-16009
+ testRead(Option(dir).map(spark.read.parquet).get, data, schema)
+
+ // Reader, with user specified schema, data should be nulls as schema in file different
+ // from user schema
+ val expData = Seq[String](null, null, null)
+ testRead(spark.read.schema(userSchema).parquet(), Seq.empty, userSchema)
+ testRead(spark.read.schema(userSchema).parquet(dir), expData, userSchema)
+ testRead(spark.read.schema(userSchema).parquet(dir, dir), expData ++ expData, userSchema)
+ testRead(
+ spark.read.schema(userSchema).parquet(Seq(dir, dir): _*), expData ++ expData, userSchema)
+ }
+
+ /**
+ * This only tests whether API compiles, but does not run it as orc()
+ * cannot be run without Hive classes.
+ */
+ ignore("orc - API") {
+ // Reader, with user specified schema
+ // Refer to csv-specific test suites for behavior without user specified schema
+ spark.read.schema(userSchema).orc()
+ spark.read.schema(userSchema).orc(dir)
+ spark.read.schema(userSchema).orc(dir, dir, dir)
+ spark.read.schema(userSchema).orc(Seq(dir, dir): _*)
+ Option(dir).map(spark.read.schema(userSchema).orc)
+
+ // Writer
+ spark.range(10).write.orc(dir)
+ }
+
+ private def testRead(
+ df: => DataFrame,
+ expectedResult: Seq[String],
+ expectedSchema: StructType): Unit = {
+ checkAnswer(df, spark.createDataset(expectedResult).toDF())
+ assert(df.schema === expectedSchema)
+ }
}