aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/pyspark/sql/readwriter.py30
-rw-r--r--python/pyspark/sql/tests.py32
2 files changed, 51 insertions, 11 deletions
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index f036644acc..1b7bc0f9a1 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -218,7 +218,10 @@ class DataFrameWriter(object):
>>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
"""
- self._jwrite = self._jwrite.mode(saveMode)
+ # At the JVM side, the default value of mode is already set to "error".
+ # So, if the given saveMode is None, we will not call JVM-side's mode method.
+ if saveMode is not None:
+ self._jwrite = self._jwrite.mode(saveMode)
return self
@since(1.4)
@@ -253,11 +256,12 @@ class DataFrameWriter(object):
"""
if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
cols = cols[0]
- self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
+ if len(cols) > 0:
+ self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
return self
@since(1.4)
- def save(self, path=None, format=None, mode="error", **options):
+ def save(self, path=None, format=None, mode=None, partitionBy=(), **options):
"""Saves the contents of the :class:`DataFrame` to a data source.
The data source is specified by the ``format`` and a set of ``options``.
@@ -272,11 +276,12 @@ class DataFrameWriter(object):
* ``overwrite``: Overwrite existing data.
* ``ignore``: Silently ignore this operation if data already exists.
* ``error`` (default case): Throw an exception if data already exists.
+ :param partitionBy: names of partitioning columns
:param options: all other string options
>>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
"""
- self.mode(mode).options(**options)
+ self.partitionBy(partitionBy).mode(mode).options(**options)
if format is not None:
self.format(format)
if path is None:
@@ -296,7 +301,7 @@ class DataFrameWriter(object):
self._jwrite.mode("overwrite" if overwrite else "append").insertInto(tableName)
@since(1.4)
- def saveAsTable(self, name, format=None, mode="error", **options):
+ def saveAsTable(self, name, format=None, mode=None, partitionBy=(), **options):
"""Saves the content of the :class:`DataFrame` as the specified table.
In the case the table already exists, behavior of this function depends on the
@@ -312,15 +317,16 @@ class DataFrameWriter(object):
:param name: the table name
:param format: the format used to save
:param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
+ :param partitionBy: names of partitioning columns
:param options: all other string options
"""
- self.mode(mode).options(**options)
+ self.partitionBy(partitionBy).mode(mode).options(**options)
if format is not None:
self.format(format)
self._jwrite.saveAsTable(name)
@since(1.4)
- def json(self, path, mode="error"):
+ def json(self, path, mode=None):
"""Saves the content of the :class:`DataFrame` in JSON format at the specified path.
:param path: the path in any Hadoop supported file system
@@ -333,10 +339,10 @@ class DataFrameWriter(object):
>>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data'))
"""
- self._jwrite.mode(mode).json(path)
+ self.mode(mode)._jwrite.json(path)
@since(1.4)
- def parquet(self, path, mode="error"):
+ def parquet(self, path, mode=None, partitionBy=()):
"""Saves the content of the :class:`DataFrame` in Parquet format at the specified path.
:param path: the path in any Hadoop supported file system
@@ -346,13 +352,15 @@ class DataFrameWriter(object):
* ``overwrite``: Overwrite existing data.
* ``ignore``: Silently ignore this operation if data already exists.
* ``error`` (default case): Throw an exception if data already exists.
+ :param partitionBy: names of partitioning columns
>>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data'))
"""
- self._jwrite.mode(mode).parquet(path)
+ self.partitionBy(partitionBy).mode(mode)
+ self._jwrite.parquet(path)
@since(1.4)
- def jdbc(self, url, table, mode="error", properties={}):
+ def jdbc(self, url, table, mode=None, properties={}):
"""Saves the content of the :class:`DataFrame` to a external database table via JDBC.
.. note:: Don't create too many partitions in parallel on a large cluster;\
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index b5fbb7d098..13f4556943 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -539,6 +539,38 @@ class SQLTests(ReusedPySparkTestCase):
shutil.rmtree(tmpPath)
+ def test_save_and_load_builder(self):
+ df = self.df
+ tmpPath = tempfile.mkdtemp()
+ shutil.rmtree(tmpPath)
+ df.write.json(tmpPath)
+ actual = self.sqlCtx.read.json(tmpPath)
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+ schema = StructType([StructField("value", StringType(), True)])
+ actual = self.sqlCtx.read.json(tmpPath, schema)
+ self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
+
+ df.write.mode("overwrite").json(tmpPath)
+ actual = self.sqlCtx.read.json(tmpPath)
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+ df.write.mode("overwrite").options(noUse="this options will not be used in save.")\
+ .format("json").save(path=tmpPath)
+ actual =\
+ self.sqlCtx.read.format("json")\
+ .load(path=tmpPath, noUse="this options will not be used in load.")
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+ defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+ actual = self.sqlCtx.load(path=tmpPath)
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+ self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+
+ shutil.rmtree(tmpPath)
+
def test_help_command(self):
# Regression test for SPARK-5464
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])