diff options
author | Andrew Or <andrew@databricks.com> | 2016-04-29 16:41:13 -0700 |
---|---|---|
committer | Andrew Or <andrew@databricks.com> | 2016-04-29 16:41:13 -0700 |
commit | d33e3d572ed7143f151f9c96fd08407f8de340f4 (patch) | |
tree | 6edbedd9e76ad883c9e9992f95cf3590cb4c1955 /python/pyspark/sql/tests.py | |
parent | 4ae9fe091c2cb8388c581093d62d3deaef40993e (diff) | |
download | spark-d33e3d572ed7143f151f9c96fd08407f8de340f4.tar.gz spark-d33e3d572ed7143f151f9c96fd08407f8de340f4.tar.bz2 spark-d33e3d572ed7143f151f9c96fd08407f8de340f4.zip |
[SPARK-14988][PYTHON] SparkSession API follow-ups
## What changes were proposed in this pull request?
Addresses comments in #12765.
## How was this patch tested?
Python tests.
Author: Andrew Or <andrew@databricks.com>
Closes #12784 from andrewor14/python-followup.
Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r-- | python/pyspark/sql/tests.py | 199 |
1 files changed, 197 insertions, 2 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1d3dc159da..ea98206836 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -45,7 +45,7 @@ if sys.version_info[:2] <= (2, 6): else: import unittest -from pyspark.sql import SQLContext, HiveContext, Column, Row +from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row from pyspark.sql.types import * from pyspark.sql.types import UserDefinedType, _infer_type from pyspark.tests import ReusedPySparkTestCase @@ -199,7 +199,8 @@ class SQLTests(ReusedPySparkTestCase): ReusedPySparkTestCase.setUpClass() cls.tempdir = tempfile.NamedTemporaryFile(delete=False) os.unlink(cls.tempdir.name) - cls.sqlCtx = SQLContext(cls.sc) + cls.sparkSession = SparkSession(cls.sc) + cls.sqlCtx = cls.sparkSession._wrapped cls.testData = [Row(key=i, value=str(i)) for i in range(100)] rdd = cls.sc.parallelize(cls.testData, 2) cls.df = rdd.toDF() @@ -1394,6 +1395,200 @@ class SQLTests(ReusedPySparkTestCase): self.assertEqual(df.schema.simpleString(), "struct<value:int>") self.assertEqual(df.collect(), [Row(key=i) for i in range(100)]) + def test_conf(self): + spark = self.sparkSession + spark.setConf("bogo", "sipeo") + self.assertEqual(self.sparkSession.conf.get("bogo"), "sipeo") + spark.setConf("bogo", "ta") + self.assertEqual(spark.conf.get("bogo"), "ta") + self.assertEqual(spark.conf.get("bogo", "not.read"), "ta") + self.assertEqual(spark.conf.get("not.set", "ta"), "ta") + self.assertRaisesRegexp(Exception, "not.set", lambda: spark.conf.get("not.set")) + spark.conf.unset("bogo") + self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia") + + def test_current_database(self): + spark = self.sparkSession + spark.catalog._reset() + self.assertEquals(spark.catalog.currentDatabase(), "default") + spark.sql("CREATE DATABASE some_db") + spark.catalog.setCurrentDatabase("some_db") + self.assertEquals(spark.catalog.currentDatabase(), "some_db") + self.assertRaisesRegexp( + AnalysisException, + "does_not_exist", + lambda: spark.catalog.setCurrentDatabase("does_not_exist")) + + def test_list_databases(self): + spark = self.sparkSession + spark.catalog._reset() + databases = [db.name for db in spark.catalog.listDatabases()] + self.assertEquals(databases, ["default"]) + spark.sql("CREATE DATABASE some_db") + databases = [db.name for db in spark.catalog.listDatabases()] + self.assertEquals(sorted(databases), ["default", "some_db"]) + + def test_list_tables(self): + from pyspark.sql.catalog import Table + spark = self.sparkSession + spark.catalog._reset() + spark.sql("CREATE DATABASE some_db") + self.assertEquals(spark.catalog.listTables(), []) + self.assertEquals(spark.catalog.listTables("some_db"), []) + spark.createDataFrame([(1, 1)]).registerTempTable("temp_tab") + spark.sql("CREATE TABLE tab1 (name STRING, age INT)") + spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT)") + tables = sorted(spark.catalog.listTables(), key=lambda t: t.name) + tablesDefault = sorted(spark.catalog.listTables("default"), key=lambda t: t.name) + tablesSomeDb = sorted(spark.catalog.listTables("some_db"), key=lambda t: t.name) + self.assertEquals(tables, tablesDefault) + self.assertEquals(len(tables), 2) + self.assertEquals(len(tablesSomeDb), 2) + self.assertEquals(tables[0], Table( + name="tab1", + database="default", + description=None, + tableType="MANAGED", + isTemporary=False)) + self.assertEquals(tables[1], Table( + name="temp_tab", + database=None, + description=None, + tableType="TEMPORARY", + isTemporary=True)) + self.assertEquals(tablesSomeDb[0], Table( + name="tab2", + database="some_db", + description=None, + tableType="MANAGED", + isTemporary=False)) + self.assertEquals(tablesSomeDb[1], Table( + name="temp_tab", + database=None, + description=None, + tableType="TEMPORARY", + isTemporary=True)) + self.assertRaisesRegexp( + AnalysisException, + "does_not_exist", + lambda: spark.catalog.listTables("does_not_exist")) + + def test_list_functions(self): + from pyspark.sql.catalog import Function + spark = self.sparkSession + spark.catalog._reset() + spark.sql("CREATE DATABASE some_db") + functions = dict((f.name, f) for f in spark.catalog.listFunctions()) + functionsDefault = dict((f.name, f) for f in spark.catalog.listFunctions("default")) + self.assertTrue(len(functions) > 200) + self.assertTrue("+" in functions) + self.assertTrue("like" in functions) + self.assertTrue("month" in functions) + self.assertTrue("to_unix_timestamp" in functions) + self.assertTrue("current_database" in functions) + self.assertEquals(functions["+"], Function( + name="+", + description=None, + className="org.apache.spark.sql.catalyst.expressions.Add", + isTemporary=True)) + self.assertEquals(functions, functionsDefault) + spark.catalog.registerFunction("temp_func", lambda x: str(x)) + spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'") + spark.sql("CREATE FUNCTION some_db.func2 AS 'org.apache.spark.data.bricks'") + newFunctions = dict((f.name, f) for f in spark.catalog.listFunctions()) + newFunctionsSomeDb = dict((f.name, f) for f in spark.catalog.listFunctions("some_db")) + self.assertTrue(set(functions).issubset(set(newFunctions))) + self.assertTrue(set(functions).issubset(set(newFunctionsSomeDb))) + self.assertTrue("temp_func" in newFunctions) + self.assertTrue("func1" in newFunctions) + self.assertTrue("func2" not in newFunctions) + self.assertTrue("temp_func" in newFunctionsSomeDb) + self.assertTrue("func1" not in newFunctionsSomeDb) + self.assertTrue("func2" in newFunctionsSomeDb) + self.assertRaisesRegexp( + AnalysisException, + "does_not_exist", + lambda: spark.catalog.listFunctions("does_not_exist")) + + def test_list_columns(self): + from pyspark.sql.catalog import Column + spark = self.sparkSession + spark.catalog._reset() + spark.sql("CREATE DATABASE some_db") + spark.sql("CREATE TABLE tab1 (name STRING, age INT)") + spark.sql("CREATE TABLE some_db.tab2 (nickname STRING, tolerance FLOAT)") + columns = sorted(spark.catalog.listColumns("tab1"), key=lambda c: c.name) + columnsDefault = sorted(spark.catalog.listColumns("tab1", "default"), key=lambda c: c.name) + self.assertEquals(columns, columnsDefault) + self.assertEquals(len(columns), 2) + self.assertEquals(columns[0], Column( + name="age", + description=None, + dataType="int", + nullable=True, + isPartition=False, + isBucket=False)) + self.assertEquals(columns[1], Column( + name="name", + description=None, + dataType="string", + nullable=True, + isPartition=False, + isBucket=False)) + columns2 = sorted(spark.catalog.listColumns("tab2", "some_db"), key=lambda c: c.name) + self.assertEquals(len(columns2), 2) + self.assertEquals(columns2[0], Column( + name="nickname", + description=None, + dataType="string", + nullable=True, + isPartition=False, + isBucket=False)) + self.assertEquals(columns2[1], Column( + name="tolerance", + description=None, + dataType="float", + nullable=True, + isPartition=False, + isBucket=False)) + self.assertRaisesRegexp( + AnalysisException, + "tab2", + lambda: spark.catalog.listColumns("tab2")) + self.assertRaisesRegexp( + AnalysisException, + "does_not_exist", + lambda: spark.catalog.listColumns("does_not_exist")) + + def test_cache(self): + spark = self.sparkSession + spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("tab1") + spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("tab2") + self.assertFalse(spark.catalog.isCached("tab1")) + self.assertFalse(spark.catalog.isCached("tab2")) + spark.catalog.cacheTable("tab1") + self.assertTrue(spark.catalog.isCached("tab1")) + self.assertFalse(spark.catalog.isCached("tab2")) + spark.catalog.cacheTable("tab2") + spark.catalog.uncacheTable("tab1") + self.assertFalse(spark.catalog.isCached("tab1")) + self.assertTrue(spark.catalog.isCached("tab2")) + spark.catalog.clearCache() + self.assertFalse(spark.catalog.isCached("tab1")) + self.assertFalse(spark.catalog.isCached("tab2")) + self.assertRaisesRegexp( + AnalysisException, + "does_not_exist", + lambda: spark.catalog.isCached("does_not_exist")) + self.assertRaisesRegexp( + AnalysisException, + "does_not_exist", + lambda: spark.catalog.cacheTable("does_not_exist")) + self.assertRaisesRegexp( + AnalysisException, + "does_not_exist", + lambda: spark.catalog.uncacheTable("does_not_exist")) + class HiveContextSQLTests(ReusedPySparkTestCase): |