author    Reynold Xin <rxin@databricks.com>    2015-11-24 18:16:07 -0800
committer Reynold Xin <rxin@databricks.com>    2015-11-24 18:16:07 -0800
commit    25bbd3c16e8e8be4d2c43000223d54650e9a3696 (patch)
tree      36a82e11908fa11f1a45447d6e493b80af240a2a
parent    238ae51b66ac12d15fba6aff061804004c5ca6cb (diff)
[SPARK-11967][SQL] Consistent use of varargs for multiple paths in DataFrameReader
This patch makes the use of varargs consistent across all DataFrameReader methods that accept multiple paths, including Parquet, JSON, text, and the generic load function. It also adds a few more API tests for the Java API.

Author: Reynold Xin <rxin@databricks.com>

Closes #9945 from rxin/SPARK-11967.
-rw-r--r--  python/pyspark/sql/readwriter.py                                          | 19
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala        | 36
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java  | 23
-rw-r--r--  sql/core/src/test/resources/text-suite2.txt                               |  1
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala         |  2
5 files changed, 66 insertions(+), 15 deletions(-)
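For readers skimming the change, a minimal sketch of how the new varargs overloads might be called from Scala follows. It is illustrative only: the SparkContext/SQLContext setup and the input paths are assumptions for the example, not part of this patch.

    // Sketch only: local setup plus hypothetical input files.
    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    val sc = new SparkContext(new SparkConf().setAppName("varargs-paths").setMaster("local[*]"))
    val sqlContext = new SQLContext(sc)

    // With the varargs signatures, multiple paths are passed directly rather than as an Array.
    val jsonDf = sqlContext.read.json("data/people1.json", "data/people2.json")
    val textDf = sqlContext.read.text("logs/part-00000.txt", "logs/part-00001.txt")
    val loadDf = sqlContext.read.format("json").load("data/people1.json", "data/people2.json")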
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index e8f0d7ec77..2e75f0c8a1 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -109,7 +109,7 @@ class DataFrameReader(object):
def load(self, path=None, format=None, schema=None, **options):
"""Loads data from a data source and returns it as a :class`DataFrame`.
- :param path: optional string for file-system backed data sources.
+ :param path: optional string or a list of strings for file-system backed data sources.
:param format: optional string for format of the data source. Default to 'parquet'.
:param schema: optional :class:`StructType` for the input schema.
:param options: all other string options
@@ -118,6 +118,7 @@ class DataFrameReader(object):
... opt2=1, opt3='str')
>>> df.dtypes
[('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
+
>>> df = sqlContext.read.format('json').load(['python/test_support/sql/people.json',
... 'python/test_support/sql/people1.json'])
>>> df.dtypes
@@ -130,10 +131,8 @@ class DataFrameReader(object):
self.options(**options)
if path is not None:
if type(path) == list:
- paths = path
- gateway = self._sqlContext._sc._gateway
- jpaths = utils.toJArray(gateway, gateway.jvm.java.lang.String, paths)
- return self._df(self._jreader.load(jpaths))
+ return self._df(
+ self._jreader.load(self._sqlContext._sc._jvm.PythonUtils.toSeq(path)))
else:
return self._df(self._jreader.load(path))
else:
@@ -175,6 +174,8 @@ class DataFrameReader(object):
self.schema(schema)
if isinstance(path, basestring):
return self._df(self._jreader.json(path))
+ elif type(path) == list:
+ return self._df(self._jreader.json(self._sqlContext._sc._jvm.PythonUtils.toSeq(path)))
elif isinstance(path, RDD):
return self._df(self._jreader.json(path._jrdd))
else:
@@ -205,16 +206,20 @@ class DataFrameReader(object):
@ignore_unicode_prefix
@since(1.6)
- def text(self, path):
+ def text(self, paths):
"""Loads a text file and returns a [[DataFrame]] with a single string column named "text".
Each line in the text file is a new row in the resulting DataFrame.
+ :param paths: string, or list of strings, for input path(s).
+
>>> df = sqlContext.read.text('python/test_support/sql/text-test.txt')
>>> df.collect()
[Row(value=u'hello'), Row(value=u'this')]
"""
- return self._df(self._jreader.text(path))
+ if isinstance(paths, basestring):
+ paths = [paths]
+ return self._df(self._jreader.text(self._sqlContext._sc._jvm.PythonUtils.toSeq(paths)))
@since(1.5)
def orc(self, path):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index dcb3737b70..3ed1e55ade 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -24,17 +24,17 @@ import scala.collection.JavaConverters._
import org.apache.hadoop.fs.Path
import org.apache.hadoop.util.StringUtils
+import org.apache.spark.{Logging, Partition}
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.SqlParser
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation}
-import org.apache.spark.sql.execution.datasources.json.{JSONOptions, JSONRelation}
+import org.apache.spark.sql.execution.datasources.json.JSONRelation
import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource}
import org.apache.spark.sql.types.StructType
-import org.apache.spark.{Logging, Partition}
-import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier}
/**
* :: Experimental ::
@@ -104,6 +104,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
*
* @since 1.4.0
*/
+ // TODO: Remove this one in Spark 2.0.
def load(path: String): DataFrame = {
option("path", path).load()
}
@@ -130,7 +131,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
*
* @since 1.6.0
*/
- def load(paths: Array[String]): DataFrame = {
+ @scala.annotation.varargs
+ def load(paths: String*): DataFrame = {
option("paths", paths.map(StringUtils.escapeString(_, '\\', ',')).mkString(",")).load()
}
@@ -236,12 +238,31 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
* <li>`allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers
* (e.g. 00012)</li>
*
- * @param path input path
* @since 1.4.0
*/
+ // TODO: Remove this one in Spark 2.0.
def json(path: String): DataFrame = format("json").load(path)
/**
+ * Loads a JSON file (one object per line) and returns the result as a [[DataFrame]].
+ *
+ * This function goes through the input once to determine the input schema. If you know the
+ * schema in advance, use the version that specifies the schema to avoid the extra scan.
+ *
+ * You can set the following JSON-specific options to deal with non-standard JSON files:
+ * <li>`primitivesAsString` (default `false`): infers all primitive values as a string type</li>
+ * <li>`allowComments` (default `false`): ignores Java/C++ style comment in JSON records</li>
+ * <li>`allowUnquotedFieldNames` (default `false`): allows unquoted JSON field names</li>
+ * <li>`allowSingleQuotes` (default `true`): allows single quotes in addition to double quotes
+ * </li>
+ * <li>`allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers
+ * (e.g. 00012)</li>
+ *
+ * @since 1.6.0
+ */
+ def json(paths: String*): DataFrame = format("json").load(paths : _*)
+
+ /**
* Loads a `JavaRDD[String]` storing JSON objects (one object per record) and
* returns the result as a [[DataFrame]].
*
@@ -328,10 +349,11 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
* sqlContext.read().text("/path/to/spark/README.md")
* }}}
*
- * @param path input path
+ * @param paths input paths
* @since 1.6.0
*/
- def text(path: String): DataFrame = format("text").load(path)
+ @scala.annotation.varargs
+ def text(paths: String*): DataFrame = format("text").load(paths : _*)
///////////////////////////////////////////////////////////////////////////////////////
// Builder pattern config options
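Because the Array[String] overload of load is replaced by a varargs signature (see the DataFrameSuite change below, where load(Array(dir1, dir2)) becomes load(dir1, dir2)), callers that already hold a collection of paths would expand it with `: _*`. A small sketch, again assuming a pre-existing `sqlContext` and hypothetical directories:

    // Expanding an existing collection of paths into the new varargs signature.
    val dirs = Seq("/tmp/json-dir1", "/tmp/json-dir2")  // hypothetical input directories
    val df = sqlContext.read.format("json").load(dirs: _*)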
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index a12fed3c0c..8e0b2dbca4 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -298,4 +298,27 @@ public class JavaDataFrameSuite {
Assert.assertEquals(48000.0, actual[1].getDouble(1), 0.01);
Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01);
}
+
+ public void testGenericLoad() {
+ DataFrame df1 = context.read().format("text").load(
+ Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString());
+ Assert.assertEquals(4L, df1.count());
+
+ DataFrame df2 = context.read().format("text").load(
+ Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(),
+ Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString());
+ Assert.assertEquals(5L, df2.count());
+ }
+
+ @Test
+ public void testTextLoad() {
+ DataFrame df1 = context.read().text(
+ Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString());
+ Assert.assertEquals(4L, df1.count());
+
+ DataFrame df2 = context.read().text(
+ Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(),
+ Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString());
+ Assert.assertEquals(5L, df2.count());
+ }
}
diff --git a/sql/core/src/test/resources/text-suite2.txt b/sql/core/src/test/resources/text-suite2.txt
new file mode 100644
index 0000000000..f9d498c804
--- /dev/null
+++ b/sql/core/src/test/resources/text-suite2.txt
@@ -0,0 +1 @@
+This is another file for testing multi path loading.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index dd6d06512f..76e9648aa7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -897,7 +897,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val dir2 = new File(dir, "dir2").getCanonicalPath
df2.write.format("json").save(dir2)
- checkAnswer(sqlContext.read.format("json").load(Array(dir1, dir2)),
+ checkAnswer(sqlContext.read.format("json").load(dir1, dir2),
Row(1, 22) :: Row(2, 23) :: Nil)
checkAnswer(sqlContext.read.format("json").load(dir1),