path: root/python/pyspark
author    Yin Huai <yhuai@databricks.com>  2015-02-10 17:29:52 -0800
committer Michael Armbrust <michael@databricks.com>  2015-02-10 17:29:52 -0800
commit    aaf50d05c7616e4f8f16654b642500ae06cdd774 (patch)
tree      7f30e0d08e4f2b531ac62c82a4361a2db577932d /python/pyspark
parent    ed167e70c6d355f39b366ea0d3b92dd26d826a0b (diff)
download  spark-aaf50d05c7616e4f8f16654b642500ae06cdd774.tar.gz
          spark-aaf50d05c7616e4f8f16654b642500ae06cdd774.tar.bz2
          spark-aaf50d05c7616e4f8f16654b642500ae06cdd774.zip
[SPARK-5658][SQL] Finalize DDL and write support APIs
https://issues.apache.org/jira/browse/SPARK-5658

Author: Yin Huai <yhuai@databricks.com>

This patch had conflicts when merged, resolved by
Committer: Michael Armbrust <michael@databricks.com>

Closes #4446 from yhuai/writeSupportFollowup and squashes the following commits:

f3a96f7 [Yin Huai] davies's comments.
225ff71 [Yin Huai] Use Scala TestHiveContext to initialize the Python HiveContext in Python tests.
2306f93 [Yin Huai] Style.
2091fcd [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup
537e28f [Yin Huai] Correctly clean up temp data.
ae4649e [Yin Huai] Fix Python test.
609129c [Yin Huai] Doc format.
92b6659 [Yin Huai] Python doc and other minor updates.
cbc717f [Yin Huai] Rename dataSourceName to source.
d1c12d3 [Yin Huai] No need to delete the duplicate rule since it has been removed in master.
22cfa70 [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup
d91ecb8 [Yin Huai] Fix test.
4c76d78 [Yin Huai] Simplify APIs.
3abc215 [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup
0832ce4 [Yin Huai] Fix test.
98e7cdb [Yin Huai] Python style.
2bf44ef [Yin Huai] Python APIs.
c204967 [Yin Huai] Format
a10223d [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup
9ff97d8 [Yin Huai] Add SaveMode to saveAsTable.
9b6e570 [Yin Huai] Update doc.
c2be775 [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup
99950a2 [Yin Huai] Use Java enum for SaveMode.
4679665 [Yin Huai] Remove duplicate rule.
77d89dc [Yin Huai] Update doc.
e04d908 [Yin Huai] Move import and add (Scala-specific) to scala APIs.
cf5703d [Yin Huai] Add checkAnswer to Java tests.
7db95ff [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup
6dfd386 [Yin Huai] Add java test.
f2f33ef [Yin Huai] Fix test.
e702386 [Yin Huai] Apache header.
b1e9b1b [Yin Huai] Format.
ed4e1b4 [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup
af9e9b3 [Yin Huai] DDL and write support API followup.
2a6213a [Yin Huai] Update API names.
e6a0b77 [Yin Huai] Update test.
43bae01 [Yin Huai] Remove createTable from HiveContext.
5ffc372 [Yin Huai] Add more load APIs to SQLContext.
5390743 [Yin Huai] Add more save APIs to DataFrame.
Diffstat (limited to 'python/pyspark')
-rw-r--r--  python/pyspark/sql/context.py   |  68
-rw-r--r--  python/pyspark/sql/dataframe.py |  72
-rw-r--r--  python/pyspark/sql/tests.py     | 107
3 files changed, 241 insertions, 6 deletions
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 49f016a9cf..882c0f98ea 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -21,6 +21,7 @@ from array import array
from itertools import imap
from py4j.protocol import Py4JError
+from py4j.java_collections import MapConverter
from pyspark.rdd import _prepare_for_python_RDD
from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
@@ -87,6 +88,18 @@ class SQLContext(object):
self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
return self._scala_SQLContext
+ def setConf(self, key, value):
+ """Sets the given Spark SQL configuration property.
+ """
+ self._ssql_ctx.setConf(key, value)
+
+ def getConf(self, key, defaultValue):
+ """Returns the value of Spark SQL configuration property for the given key.
+
+ If the key is not set, returns defaultValue.
+ """
+ return self._ssql_ctx.getConf(key, defaultValue)
+
def registerFunction(self, name, f, returnType=StringType()):
"""Registers a lambda function as a UDF so it can be used in SQL statements.
@@ -455,6 +468,61 @@ class SQLContext(object):
df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
return DataFrame(df, self)
+ def load(self, path=None, source=None, schema=None, **options):
+ """Returns the dataset in a data source as a DataFrame.
+
+ The data source is specified by the `source` and a set of `options`.
+ If `source` is not specified, the default data source configured by
+ spark.sql.sources.default will be used.
+
+ Optionally, a schema can be provided as the schema of the returned DataFrame.
+ """
+ if path is not None:
+ options["path"] = path
+ if source is None:
+ source = self.getConf("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ joptions = MapConverter().convert(options,
+ self._sc._gateway._gateway_client)
+ if schema is None:
+ df = self._ssql_ctx.load(source, joptions)
+ else:
+ if not isinstance(schema, StructType):
+ raise TypeError("schema should be StructType")
+ scala_datatype = self._ssql_ctx.parseDataType(schema.json())
+ df = self._ssql_ctx.load(source, scala_datatype, joptions)
+ return DataFrame(df, self)
+
+ def createExternalTable(self, tableName, path=None, source=None,
+ schema=None, **options):
+ """Creates an external table based on the dataset in a data source.
+
+ It returns the DataFrame associated with the external table.
+
+ The data source is specified by the `source` and a set of `options`.
+ If `source` is not specified, the default data source configured by
+ spark.sql.sources.default will be used.
+
+ Optionally, a schema can be provided as the schema of the returned DataFrame and
+ created external table.
+ """
+ if path is not None:
+ options["path"] = path
+ if source is None:
+ source = self.getConf("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ joptions = MapConverter().convert(options,
+ self._sc._gateway._gateway_client)
+ if schema is None:
+ df = self._ssql_ctx.createExternalTable(tableName, source, joptions)
+ else:
+ if not isinstance(schema, StructType):
+ raise TypeError("schema should be StructType")
+ scala_datatype = self._ssql_ctx.parseDataType(schema.json())
+ df = self._ssql_ctx.createExternalTable(tableName, source, scala_datatype,
+ joptions)
+ return DataFrame(df, self)
+
def sql(self, sqlQuery):
"""Return a L{DataFrame} representing the result of the given query.
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 04be65fe24..3eef0cc376 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -146,9 +146,75 @@ class DataFrame(object):
"""
self._jdf.insertInto(tableName, overwrite)
- def saveAsTable(self, tableName):
- """Creates a new table with the contents of this DataFrame."""
- self._jdf.saveAsTable(tableName)
+ def _java_save_mode(self, mode):
+ """Returns the Java save mode based on the Python save mode represented by a string.
+ """
+ jSaveMode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode
+ jmode = jSaveMode.ErrorIfExists
+ mode = mode.lower()
+ if mode == "append":
+ jmode = jSaveMode.Append
+ elif mode == "overwrite":
+ jmode = jSaveMode.Overwrite
+ elif mode == "ignore":
+ jmode = jSaveMode.Ignore
+ elif mode == "error":
+ pass
+ else:
+ raise ValueError(
+ "Only 'append', 'overwrite', 'ignore', and 'error' are acceptable save mode.")
+ return jmode
+
+ def saveAsTable(self, tableName, source=None, mode="append", **options):
+ """Saves the contents of the DataFrame to a data source as a table.
+
+ The data source is specified by the `source` and a set of `options`.
+ If `source` is not specified, the default data source configured by
+ spark.sql.sources.default will be used.
+
+ Additionally, mode is used to specify the behavior of the saveAsTable operation when
+ the table already exists in the data source. There are four modes:
+
+ * append: Contents of this DataFrame are expected to be appended to existing table.
+ * overwrite: Data in the existing table is expected to be overwritten by the contents of \
+ this DataFrame.
+ * error: An exception is expected to be thrown.
+ * ignore: The save operation is expected to not save the contents of the DataFrame and \
+ to not change the existing table.
+ """
+ if source is None:
+ source = self.sql_ctx.getConf("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ jmode = self._java_save_mode(mode)
+ joptions = MapConverter().convert(options,
+ self.sql_ctx._sc._gateway._gateway_client)
+ self._jdf.saveAsTable(tableName, source, jmode, joptions)
+
+ def save(self, path=None, source=None, mode="append", **options):
+ """Saves the contents of the DataFrame to a data source.
+
+ The data source is specified by the `source` and a set of `options`.
+ If `source` is not specified, the default data source configured by
+ spark.sql.sources.default will be used.
+
+ Additionally, mode is used to specify the behavior of the save operation when
+ data already exists in the data source. There are four modes:
+
+ * append: Contents of this DataFrame are expected to be appended to existing data.
+ * overwrite: Existing data is expected to be overwritten by the contents of this DataFrame.
+ * error: An exception is expected to be thrown.
+ * ignore: The save operation is expected to not save the contents of the DataFrame and \
+ to not change the existing data.
+ """
+ if path is not None:
+ options["path"] = path
+ if source is None:
+ source = self.sql_ctx.getConf("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ jmode = self._java_save_mode(mode)
+ joptions = MapConverter().convert(options,
+ self._sc._gateway._gateway_client)
+ self._jdf.save(source, jmode, joptions)
def schema(self):
"""Returns the schema of this DataFrame (represented by
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index d25c6365ed..bc945091f7 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -34,10 +34,9 @@ if sys.version_info[:2] <= (2, 6):
else:
import unittest
-
-from pyspark.sql import SQLContext, Column
+from pyspark.sql import SQLContext, HiveContext, Column
from pyspark.sql.types import IntegerType, Row, ArrayType, StructType, StructField, \
- UserDefinedType, DoubleType, LongType
+ UserDefinedType, DoubleType, LongType, StringType
from pyspark.tests import ReusedPySparkTestCase
@@ -286,6 +285,37 @@ class SQLTests(ReusedPySparkTestCase):
self.assertTrue(95 < g.agg(Dsl.approxCountDistinct(df.key)).first()[0])
self.assertEqual(100, g.agg(Dsl.countDistinct(df.value)).first()[0])
+ def test_save_and_load(self):
+ df = self.df
+ tmpPath = tempfile.mkdtemp()
+ shutil.rmtree(tmpPath)
+ df.save(tmpPath, "org.apache.spark.sql.json", "error")
+ actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json")
+ self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+
+ schema = StructType([StructField("value", StringType(), True)])
+ actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json", schema)
+ self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect()))
+
+ df.save(tmpPath, "org.apache.spark.sql.json", "overwrite")
+ actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json")
+ self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+
+ df.save(source="org.apache.spark.sql.json", mode="overwrite", path=tmpPath,
+ noUse="this options will not be used in save.")
+ actual = self.sqlCtx.load(source="org.apache.spark.sql.json", path=tmpPath,
+ noUse="this options will not be used in load.")
+ self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+
+ defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+ actual = self.sqlCtx.load(path=tmpPath)
+ self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+ self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+
+ shutil.rmtree(tmpPath)
+
def test_help_command(self):
# Regression test for SPARK-5464
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
@@ -296,5 +326,76 @@ class SQLTests(ReusedPySparkTestCase):
pydoc.render_doc(df.take(1))
+class HiveContextSQLTests(ReusedPySparkTestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ ReusedPySparkTestCase.setUpClass()
+ cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+ os.unlink(cls.tempdir.name)
+ print "type", type(cls.sc)
+ print "type", type(cls.sc._jsc)
+ _scala_HiveContext =\
+ cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc())
+ cls.sqlCtx = HiveContext(cls.sc, _scala_HiveContext)
+ cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
+ rdd = cls.sc.parallelize(cls.testData)
+ cls.df = cls.sqlCtx.inferSchema(rdd)
+
+ @classmethod
+ def tearDownClass(cls):
+ ReusedPySparkTestCase.tearDownClass()
+ shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+
+ def test_save_and_load_table(self):
+ df = self.df
+ tmpPath = tempfile.mkdtemp()
+ shutil.rmtree(tmpPath)
+ df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "append", path=tmpPath)
+ actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath,
+ "org.apache.spark.sql.json")
+ self.assertTrue(
+ sorted(df.collect()) ==
+ sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+ self.assertTrue(
+ sorted(df.collect()) ==
+ sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+ self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+ self.sqlCtx.sql("DROP TABLE externalJsonTable")
+
+ df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "overwrite", path=tmpPath)
+ schema = StructType([StructField("value", StringType(), True)])
+ actual = self.sqlCtx.createExternalTable("externalJsonTable",
+ source="org.apache.spark.sql.json",
+ schema=schema, path=tmpPath,
+ noUse="this options will not be used")
+ self.assertTrue(
+ sorted(df.collect()) ==
+ sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+ self.assertTrue(
+ sorted(df.select("value").collect()) ==
+ sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+ self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect()))
+ self.sqlCtx.sql("DROP TABLE savedJsonTable")
+ self.sqlCtx.sql("DROP TABLE externalJsonTable")
+
+ defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+ df.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
+ actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath)
+ self.assertTrue(
+ sorted(df.collect()) ==
+ sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+ self.assertTrue(
+ sorted(df.collect()) ==
+ sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+ self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+ self.sqlCtx.sql("DROP TABLE savedJsonTable")
+ self.sqlCtx.sql("DROP TABLE externalJsonTable")
+ self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+
+ shutil.rmtree(tmpPath)
+
if __name__ == "__main__":
unittest.main()