author     Andrew Or <andrew@databricks.com>  2016-04-29 16:41:13 -0700
committer  Andrew Or <andrew@databricks.com>  2016-04-29 16:41:13 -0700
commit     d33e3d572ed7143f151f9c96fd08407f8de340f4 (patch)
tree       6edbedd9e76ad883c9e9992f95cf3590cb4c1955 /python/pyspark/sql/tests.py
parent     4ae9fe091c2cb8388c581093d62d3deaef40993e (diff)
[SPARK-14988][PYTHON] SparkSession API follow-ups
## What changes were proposed in this pull request?

Addresses comments in #12765.

## How was this patch tested?

Python tests.

Author: Andrew Or <andrew@databricks.com>

Closes #12784 from andrewor14/python-followup.
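For context, the session-level API exercised by the new tests can be driven end to end as in the sketch below. This is a minimal illustration, assuming a local `SparkContext`; it mirrors the calls made in the tests rather than documenting a finalized public API.

```python
from pyspark import SparkContext
from pyspark.sql import SparkSession

# A local SparkContext, analogous to the shared cls.sc in the test suite.
sc = SparkContext("local[2]", "session-demo")
spark = SparkSession(sc)

# Runtime conf: set a key, read it back (optionally with a default), unset it.
spark.setConf("bogo", "sipeo")
assert spark.conf.get("bogo") == "sipeo"
assert spark.conf.get("not.set", "fallback") == "fallback"
spark.conf.unset("bogo")
assert spark.conf.get("bogo", "colombia") == "colombia"

# Catalog: inspect databases and tables through the session's catalog.
print([db.name for db in spark.catalog.listDatabases()])  # ['default']
print([t.name for t in spark.catalog.listTables()])       # [] on a fresh session

sc.stop()
```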
Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r--  python/pyspark/sql/tests.py  199
1 file changed, 197 insertions(+), 2 deletions(-)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1d3dc159da..ea98206836 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -45,7 +45,7 @@ if sys.version_info[:2] <= (2, 6):
else:
    import unittest

-from pyspark.sql import SQLContext, HiveContext, Column, Row
+from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
from pyspark.sql.types import *
from pyspark.sql.types import UserDefinedType, _infer_type
from pyspark.tests import ReusedPySparkTestCase
@@ -199,7 +199,8 @@ class SQLTests(ReusedPySparkTestCase):
        ReusedPySparkTestCase.setUpClass()
        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
        os.unlink(cls.tempdir.name)
-        cls.sqlCtx = SQLContext(cls.sc)
+        cls.sparkSession = SparkSession(cls.sc)
+        cls.sqlCtx = cls.sparkSession._wrapped
        cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
        rdd = cls.sc.parallelize(cls.testData, 2)
        cls.df = rdd.toDF()
@@ -1394,6 +1395,200 @@ class SQLTests(ReusedPySparkTestCase):
        self.assertEqual(df.schema.simpleString(), "struct<value:int>")
        self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])

+    def test_conf(self):
+        spark = self.sparkSession
+        spark.setConf("bogo", "sipeo")
+        self.assertEqual(self.sparkSession.conf.get("bogo"), "sipeo")
+        spark.setConf("bogo", "ta")
+        self.assertEqual(spark.conf.get("bogo"), "ta")
+        self.assertEqual(spark.conf.get("bogo", "not.read"), "ta")
+        self.assertEqual(spark.conf.get("not.set", "ta"), "ta")
+        self.assertRaisesRegexp(Exception, "not.set", lambda: spark.conf.get("not.set"))
+        spark.conf.unset("bogo")
+        self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia")
+
+    def test_current_database(self):
+        spark = self.sparkSession
+        spark.catalog._reset()
+        self.assertEquals(spark.catalog.currentDatabase(), "default")
+        spark.sql("CREATE DATABASE some_db")
+        spark.catalog.setCurrentDatabase("some_db")
+        self.assertEquals(spark.catalog.currentDatabase(), "some_db")
+        self.assertRaisesRegexp(
+            AnalysisException,
+            "does_not_exist",
+            lambda: spark.catalog.setCurrentDatabase("does_not_exist"))
+
+    def test_list_databases(self):
+        spark = self.sparkSession
+        spark.catalog._reset()
+        databases = [db.name for db in spark.catalog.listDatabases()]
+        self.assertEquals(databases, ["default"])
+        spark.sql("CREATE DATABASE some_db")
+        databases = [db.name for db in spark.catalog.listDatabases()]
+        self.assertEquals(sorted(databases), ["default", "some_db"])
+
+    def test_list_tables(self):
+        from pyspark.sql.catalog import Table
+        spark = self.sparkSession
+        spark.catalog._reset()
+        spark.sql("CREATE DATABASE some_db")
+        self.assertEquals(spark.catalog.listTables(), [])
+        self.assertEquals(spark.catalog.listTables("some_db"), [])
+        spark.createDataFrame([(1, 1)]).registerTempTable("temp_tab")
+        spark.sql("CREATE TABLE tab1 (name STRING, age INT)")
+        spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT)")
+        tables = sorted(spark.catalog.listTables(), key=lambda t: t.name)
+        tablesDefault = sorted(spark.catalog.listTables("default"), key=lambda t: t.name)
+        tablesSomeDb = sorted(spark.catalog.listTables("some_db"), key=lambda t: t.name)
+        self.assertEquals(tables, tablesDefault)
+        self.assertEquals(len(tables), 2)
+        self.assertEquals(len(tablesSomeDb), 2)
+        self.assertEquals(tables[0], Table(
+            name="tab1",
+            database="default",
+            description=None,
+            tableType="MANAGED",
+            isTemporary=False))
+        self.assertEquals(tables[1], Table(
+            name="temp_tab",
+            database=None,
+            description=None,
+            tableType="TEMPORARY",
+            isTemporary=True))
+        self.assertEquals(tablesSomeDb[0], Table(
+            name="tab2",
+            database="some_db",
+            description=None,
+            tableType="MANAGED",
+            isTemporary=False))
+        self.assertEquals(tablesSomeDb[1], Table(
+            name="temp_tab",
+            database=None,
+            description=None,
+            tableType="TEMPORARY",
+            isTemporary=True))
+        self.assertRaisesRegexp(
+            AnalysisException,
+            "does_not_exist",
+            lambda: spark.catalog.listTables("does_not_exist"))
+
+    def test_list_functions(self):
+        from pyspark.sql.catalog import Function
+        spark = self.sparkSession
+        spark.catalog._reset()
+        spark.sql("CREATE DATABASE some_db")
+        functions = dict((f.name, f) for f in spark.catalog.listFunctions())
+        functionsDefault = dict((f.name, f) for f in spark.catalog.listFunctions("default"))
+        self.assertTrue(len(functions) > 200)
+        self.assertTrue("+" in functions)
+        self.assertTrue("like" in functions)
+        self.assertTrue("month" in functions)
+        self.assertTrue("to_unix_timestamp" in functions)
+        self.assertTrue("current_database" in functions)
+        self.assertEquals(functions["+"], Function(
+            name="+",
+            description=None,
+            className="org.apache.spark.sql.catalyst.expressions.Add",
+            isTemporary=True))
+        self.assertEquals(functions, functionsDefault)
+        spark.catalog.registerFunction("temp_func", lambda x: str(x))
+        spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'")
+        spark.sql("CREATE FUNCTION some_db.func2 AS 'org.apache.spark.data.bricks'")
+        newFunctions = dict((f.name, f) for f in spark.catalog.listFunctions())
+        newFunctionsSomeDb = dict((f.name, f) for f in spark.catalog.listFunctions("some_db"))
+        self.assertTrue(set(functions).issubset(set(newFunctions)))
+        self.assertTrue(set(functions).issubset(set(newFunctionsSomeDb)))
+        self.assertTrue("temp_func" in newFunctions)
+        self.assertTrue("func1" in newFunctions)
+        self.assertTrue("func2" not in newFunctions)
+        self.assertTrue("temp_func" in newFunctionsSomeDb)
+        self.assertTrue("func1" not in newFunctionsSomeDb)
+        self.assertTrue("func2" in newFunctionsSomeDb)
+        self.assertRaisesRegexp(
+            AnalysisException,
+            "does_not_exist",
+            lambda: spark.catalog.listFunctions("does_not_exist"))
+
+    def test_list_columns(self):
+        from pyspark.sql.catalog import Column
+        spark = self.sparkSession
+        spark.catalog._reset()
+        spark.sql("CREATE DATABASE some_db")
+        spark.sql("CREATE TABLE tab1 (name STRING, age INT)")
+        spark.sql("CREATE TABLE some_db.tab2 (nickname STRING, tolerance FLOAT)")
+        columns = sorted(spark.catalog.listColumns("tab1"), key=lambda c: c.name)
+        columnsDefault = sorted(spark.catalog.listColumns("tab1", "default"), key=lambda c: c.name)
+        self.assertEquals(columns, columnsDefault)
+        self.assertEquals(len(columns), 2)
+        self.assertEquals(columns[0], Column(
+            name="age",
+            description=None,
+            dataType="int",
+            nullable=True,
+            isPartition=False,
+            isBucket=False))
+        self.assertEquals(columns[1], Column(
+            name="name",
+            description=None,
+            dataType="string",
+            nullable=True,
+            isPartition=False,
+            isBucket=False))
+        columns2 = sorted(spark.catalog.listColumns("tab2", "some_db"), key=lambda c: c.name)
+        self.assertEquals(len(columns2), 2)
+        self.assertEquals(columns2[0], Column(
+            name="nickname",
+            description=None,
+            dataType="string",
+            nullable=True,
+            isPartition=False,
+            isBucket=False))
+        self.assertEquals(columns2[1], Column(
+            name="tolerance",
+            description=None,
+            dataType="float",
+            nullable=True,
+            isPartition=False,
+            isBucket=False))
+        self.assertRaisesRegexp(
+            AnalysisException,
+            "tab2",
+            lambda: spark.catalog.listColumns("tab2"))
+        self.assertRaisesRegexp(
+            AnalysisException,
+            "does_not_exist",
+            lambda: spark.catalog.listColumns("does_not_exist"))
+
+    def test_cache(self):
+        spark = self.sparkSession
+        spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("tab1")
+        spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("tab2")
+        self.assertFalse(spark.catalog.isCached("tab1"))
+        self.assertFalse(spark.catalog.isCached("tab2"))
+        spark.catalog.cacheTable("tab1")
+        self.assertTrue(spark.catalog.isCached("tab1"))
+        self.assertFalse(spark.catalog.isCached("tab2"))
+        spark.catalog.cacheTable("tab2")
+        spark.catalog.uncacheTable("tab1")
+        self.assertFalse(spark.catalog.isCached("tab1"))
+        self.assertTrue(spark.catalog.isCached("tab2"))
+        spark.catalog.clearCache()
+        self.assertFalse(spark.catalog.isCached("tab1"))
+        self.assertFalse(spark.catalog.isCached("tab2"))
+        self.assertRaisesRegexp(
+            AnalysisException,
+            "does_not_exist",
+            lambda: spark.catalog.isCached("does_not_exist"))
+        self.assertRaisesRegexp(
+            AnalysisException,
+            "does_not_exist",
+            lambda: spark.catalog.cacheTable("does_not_exist"))
+        self.assertRaisesRegexp(
+            AnalysisException,
+            "does_not_exist",
+            lambda: spark.catalog.uncacheTable("does_not_exist"))
+
class HiveContextSQLTests(ReusedPySparkTestCase):