From aaf50d05c7616e4f8f16654b642500ae06cdd774 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 10 Feb 2015 17:29:52 -0800 Subject: [SPARK-5658][SQL] Finalize DDL and write support APIs https://issues.apache.org/jira/browse/SPARK-5658 Author: Yin Huai This patch had conflicts when merged, resolved by Committer: Michael Armbrust Closes #4446 from yhuai/writeSupportFollowup and squashes the following commits: f3a96f7 [Yin Huai] davies's comments. 225ff71 [Yin Huai] Use Scala TestHiveContext to initialize the Python HiveContext in Python tests. 2306f93 [Yin Huai] Style. 2091fcd [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup 537e28f [Yin Huai] Correctly clean up temp data. ae4649e [Yin Huai] Fix Python test. 609129c [Yin Huai] Doc format. 92b6659 [Yin Huai] Python doc and other minor updates. cbc717f [Yin Huai] Rename dataSourceName to source. d1c12d3 [Yin Huai] No need to delete the duplicate rule since it has been removed in master. 22cfa70 [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup d91ecb8 [Yin Huai] Fix test. 4c76d78 [Yin Huai] Simplify APIs. 3abc215 [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup 0832ce4 [Yin Huai] Fix test. 98e7cdb [Yin Huai] Python style. 2bf44ef [Yin Huai] Python APIs. c204967 [Yin Huai] Format a10223d [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup 9ff97d8 [Yin Huai] Add SaveMode to saveAsTable. 9b6e570 [Yin Huai] Update doc. c2be775 [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup 99950a2 [Yin Huai] Use Java enum for SaveMode. 4679665 [Yin Huai] Remove duplicate rule. 77d89dc [Yin Huai] Update doc. e04d908 [Yin Huai] Move import and add (Scala-specific) to scala APIs. cf5703d [Yin Huai] Add checkAnswer to Java tests. 7db95ff [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup 6dfd386 [Yin Huai] Add java test. f2f33ef [Yin Huai] Fix test. e702386 [Yin Huai] Apache header. b1e9b1b [Yin Huai] Format. ed4e1b4 [Yin Huai] Merge remote-tracking branch 'upstream/master' into writeSupportFollowup af9e9b3 [Yin Huai] DDL and write support API followup. 2a6213a [Yin Huai] Update API names. e6a0b77 [Yin Huai] Update test. 43bae01 [Yin Huai] Remove createTable from HiveContext. 5ffc372 [Yin Huai] Add more load APIs to SQLContext. 5390743 [Yin Huai] Add more save APIs to DataFrame. --- python/pyspark/sql/tests.py | 107 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 104 insertions(+), 3 deletions(-) (limited to 'python/pyspark/sql/tests.py') diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index d25c6365ed..bc945091f7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -34,10 +34,9 @@ if sys.version_info[:2] <= (2, 6): else: import unittest - -from pyspark.sql import SQLContext, Column +from pyspark.sql import SQLContext, HiveContext, Column from pyspark.sql.types import IntegerType, Row, ArrayType, StructType, StructField, \ - UserDefinedType, DoubleType, LongType + UserDefinedType, DoubleType, LongType, StringType from pyspark.tests import ReusedPySparkTestCase @@ -286,6 +285,37 @@ class SQLTests(ReusedPySparkTestCase): self.assertTrue(95 < g.agg(Dsl.approxCountDistinct(df.key)).first()[0]) self.assertEqual(100, g.agg(Dsl.countDistinct(df.value)).first()[0]) + def test_save_and_load(self): + df = self.df + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + df.save(tmpPath, "org.apache.spark.sql.json", "error") + actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json") + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + + schema = StructType([StructField("value", StringType(), True)]) + actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json", schema) + self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect())) + + df.save(tmpPath, "org.apache.spark.sql.json", "overwrite") + actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json") + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + + df.save(source="org.apache.spark.sql.json", mode="overwrite", path=tmpPath, + noUse="this options will not be used in save.") + actual = self.sqlCtx.load(source="org.apache.spark.sql.json", path=tmpPath, + noUse="this options will not be used in load.") + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + + defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") + actual = self.sqlCtx.load(path=tmpPath) + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) + + shutil.rmtree(tmpPath) + def test_help_command(self): # Regression test for SPARK-5464 rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) @@ -296,5 +326,76 @@ class SQLTests(ReusedPySparkTestCase): pydoc.render_doc(df.take(1)) +class HiveContextSQLTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(cls.tempdir.name) + print "type", type(cls.sc) + print "type", type(cls.sc._jsc) + _scala_HiveContext =\ + cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc()) + cls.sqlCtx = HiveContext(cls.sc, _scala_HiveContext) + cls.testData = [Row(key=i, value=str(i)) for i in range(100)] + rdd = cls.sc.parallelize(cls.testData) + cls.df = cls.sqlCtx.inferSchema(rdd) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name, ignore_errors=True) + + def test_save_and_load_table(self): + df = self.df + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "append", path=tmpPath) + actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath, + "org.apache.spark.sql.json") + self.assertTrue( + sorted(df.collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) + self.assertTrue( + sorted(df.collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + self.sqlCtx.sql("DROP TABLE externalJsonTable") + + df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "overwrite", path=tmpPath) + schema = StructType([StructField("value", StringType(), True)]) + actual = self.sqlCtx.createExternalTable("externalJsonTable", + source="org.apache.spark.sql.json", + schema=schema, path=tmpPath, + noUse="this options will not be used") + self.assertTrue( + sorted(df.collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) + self.assertTrue( + sorted(df.select("value").collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) + self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect())) + self.sqlCtx.sql("DROP TABLE savedJsonTable") + self.sqlCtx.sql("DROP TABLE externalJsonTable") + + defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") + df.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite") + actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath) + self.assertTrue( + sorted(df.collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) + self.assertTrue( + sorted(df.collect()) == + sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) + self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + self.sqlCtx.sql("DROP TABLE savedJsonTable") + self.sqlCtx.sql("DROP TABLE externalJsonTable") + self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) + + shutil.rmtree(tmpPath) + if __name__ == "__main__": unittest.main() -- cgit v1.2.3