author    Andrew Or <andrew@databricks.com>  2016-04-29 16:41:13 -0700
committer Andrew Or <andrew@databricks.com>  2016-04-29 16:41:13 -0700
commit    d33e3d572ed7143f151f9c96fd08407f8de340f4 (patch)
tree      6edbedd9e76ad883c9e9992f95cf3590cb4c1955 /python
parent    4ae9fe091c2cb8388c581093d62d3deaef40993e (diff)
[SPARK-14988][PYTHON] SparkSession API follow-ups
## What changes were proposed in this pull request?

Addresses comments in #12765.

## How was this patch tested?

Python tests.

Author: Andrew Or <andrew@databricks.com>

Closes #12784 from andrewor14/python-followup.
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/sql/catalog.py | 168
-rw-r--r--  python/pyspark/sql/conf.py    |  58
-rw-r--r--  python/pyspark/sql/context.py |   8
-rw-r--r--  python/pyspark/sql/session.py |   4
-rw-r--r--  python/pyspark/sql/tests.py   | 199
5 files changed, 228 insertions(+), 209 deletions(-)
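For orientation, here is a minimal sketch of the API surface this follow-up settles on, assuming an existing SparkSession named `spark` (the method names match the diff below; the config keys and table names are illustrative only):

# Runtime config: get() and getOption() are merged into get(key, default=None).
spark.conf.set("spark.sql.shuffle.partitions", "10")
spark.conf.get("spark.sql.shuffle.partitions")     # u'10'
spark.conf.get("not.set.anywhere", "fallback")     # the default replaces the old getOption()
spark.conf.unset("spark.sql.shuffle.partitions")

# Catalog: registerDataFrameAsTable is renamed to registerTable.
df = spark.createDataFrame([(1, "a")])
spark.catalog.registerTable(df, "t")
spark.catalog.cacheTable("t")
spark.catalog.isCached("t")                        # True
spark.catalog.clearCache()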
diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
index 4f9238374a..9cfdd0a99f 100644
--- a/python/pyspark/sql/catalog.py
+++ b/python/pyspark/sql/catalog.py
@@ -45,45 +45,19 @@ class Catalog(object):
@ignore_unicode_prefix
@since(2.0)
def currentDatabase(self):
- """Returns the current default database in this session.
-
- >>> spark.catalog._reset()
- >>> spark.catalog.currentDatabase()
- u'default'
- """
+ """Returns the current default database in this session."""
return self._jcatalog.currentDatabase()
@ignore_unicode_prefix
@since(2.0)
def setCurrentDatabase(self, dbName):
- """Sets the current default database in this session.
-
- >>> spark.catalog._reset()
- >>> spark.sql("CREATE DATABASE some_db")
- DataFrame[]
- >>> spark.catalog.setCurrentDatabase("some_db")
- >>> spark.catalog.currentDatabase()
- u'some_db'
- >>> spark.catalog.setCurrentDatabase("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- AnalysisException: ...
- """
+ """Sets the current default database in this session."""
return self._jcatalog.setCurrentDatabase(dbName)
@ignore_unicode_prefix
@since(2.0)
def listDatabases(self):
- """Returns a list of databases available across all sessions.
-
- >>> spark.catalog._reset()
- >>> [db.name for db in spark.catalog.listDatabases()]
- [u'default']
- >>> spark.sql("CREATE DATABASE some_db")
- DataFrame[]
- >>> [db.name for db in spark.catalog.listDatabases()]
- [u'default', u'some_db']
- """
+ """Returns a list of databases available across all sessions."""
iter = self._jcatalog.listDatabases().toLocalIterator()
databases = []
while iter.hasNext():
@@ -101,31 +75,6 @@ class Catalog(object):
If no database is specified, the current database is used.
This includes all temporary tables.
-
- >>> spark.catalog._reset()
- >>> spark.sql("CREATE DATABASE some_db")
- DataFrame[]
- >>> spark.catalog.listTables()
- []
- >>> spark.catalog.listTables("some_db")
- []
- >>> spark.createDataFrame([(1, 1)]).registerTempTable("my_temp_tab")
- >>> spark.sql("CREATE TABLE my_tab1 (name STRING, age INT)")
- DataFrame[]
- >>> spark.sql("CREATE TABLE some_db.my_tab2 (name STRING, age INT)")
- DataFrame[]
- >>> spark.catalog.listTables()
- [Table(name=u'my_tab1', database=u'default', description=None, tableType=u'MANAGED',
- isTemporary=False), Table(name=u'my_temp_tab', database=None, description=None,
- tableType=u'TEMPORARY', isTemporary=True)]
- >>> spark.catalog.listTables("some_db")
- [Table(name=u'my_tab2', database=u'some_db', description=None, tableType=u'MANAGED',
- isTemporary=False), Table(name=u'my_temp_tab', database=None, description=None,
- tableType=u'TEMPORARY', isTemporary=True)]
- >>> spark.catalog.listTables("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- AnalysisException: ...
"""
if dbName is None:
dbName = self.currentDatabase()
@@ -148,28 +97,6 @@ class Catalog(object):
If no database is specified, the current database is used.
This includes all temporary functions.
-
- >>> spark.catalog._reset()
- >>> spark.sql("CREATE DATABASE my_db")
- DataFrame[]
- >>> funcNames = set(f.name for f in spark.catalog.listFunctions())
- >>> set(["+", "floor", "to_unix_timestamp", "current_database"]).issubset(funcNames)
- True
- >>> spark.sql("CREATE FUNCTION my_func1 AS 'org.apache.spark.whatever'")
- DataFrame[]
- >>> spark.sql("CREATE FUNCTION my_db.my_func2 AS 'org.apache.spark.whatever'")
- DataFrame[]
- >>> spark.catalog.registerFunction("temp_func", lambda x: str(x))
- >>> newFuncNames = set(f.name for f in spark.catalog.listFunctions()) - funcNames
- >>> newFuncNamesDb = set(f.name for f in spark.catalog.listFunctions("my_db")) - funcNames
- >>> sorted(list(newFuncNames))
- [u'my_func1', u'temp_func']
- >>> sorted(list(newFuncNamesDb))
- [u'my_func2', u'temp_func']
- >>> spark.catalog.listFunctions("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- AnalysisException: ...
"""
if dbName is None:
dbName = self.currentDatabase()
@@ -193,26 +120,6 @@ class Catalog(object):
Note: the order of arguments here is different from that of its JVM counterpart
because Python does not support method overloading.
-
- >>> spark.catalog._reset()
- >>> spark.sql("CREATE DATABASE some_db")
- DataFrame[]
- >>> spark.sql("CREATE TABLE my_tab1 (name STRING, age INT)")
- DataFrame[]
- >>> spark.sql("CREATE TABLE some_db.my_tab2 (nickname STRING, tolerance FLOAT)")
- DataFrame[]
- >>> spark.catalog.listColumns("my_tab1")
- [Column(name=u'name', description=None, dataType=u'string', nullable=True,
- isPartition=False, isBucket=False), Column(name=u'age', description=None,
- dataType=u'int', nullable=True, isPartition=False, isBucket=False)]
- >>> spark.catalog.listColumns("my_tab2", "some_db")
- [Column(name=u'nickname', description=None, dataType=u'string', nullable=True,
- isPartition=False, isBucket=False), Column(name=u'tolerance', description=None,
- dataType=u'float', nullable=True, isPartition=False, isBucket=False)]
- >>> spark.catalog.listColumns("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- AnalysisException: ...
"""
if dbName is None:
dbName = self.currentDatabase()
@@ -247,7 +154,7 @@ class Catalog(object):
if path is not None:
options["path"] = path
if source is None:
- source = self._sparkSession.getConf(
+ source = self._sparkSession.conf.get(
"spark.sql.sources.default", "org.apache.spark.sql.parquet")
if schema is None:
df = self._jcatalog.createExternalTable(tableName, source, options)
@@ -275,16 +182,16 @@ class Catalog(object):
self._jcatalog.dropTempTable(tableName)
@since(2.0)
- def registerDataFrameAsTable(self, df, tableName):
+ def registerTable(self, df, tableName):
"""Registers the given :class:`DataFrame` as a temporary table in the catalog.
>>> df = spark.createDataFrame([(2, 1), (3, 1)])
- >>> spark.catalog.registerDataFrameAsTable(df, "my_cool_table")
+ >>> spark.catalog.registerTable(df, "my_cool_table")
>>> spark.table("my_cool_table").collect()
[Row(_1=2, _2=1), Row(_1=3, _2=1)]
"""
if isinstance(df, DataFrame):
- self._jsparkSession.registerDataFrameAsTable(df._jdf, tableName)
+ self._jsparkSession.registerTable(df._jdf, tableName)
else:
raise ValueError("Can only register DataFrame as table")
@@ -321,75 +228,22 @@ class Catalog(object):
@since(2.0)
def isCached(self, tableName):
- """Returns true if the table is currently cached in-memory.
-
- >>> spark.catalog._reset()
- >>> spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("my_tab")
- >>> spark.catalog.isCached("my_tab")
- False
- >>> spark.catalog.cacheTable("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- AnalysisException: ...
- """
+ """Returns true if the table is currently cached in-memory."""
return self._jcatalog.isCached(tableName)
@since(2.0)
def cacheTable(self, tableName):
- """Caches the specified table in-memory.
-
- >>> spark.catalog._reset()
- >>> spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("my_tab")
- >>> spark.catalog.isCached("my_tab")
- False
- >>> spark.catalog.cacheTable("my_tab")
- >>> spark.catalog.isCached("my_tab")
- True
- >>> spark.catalog.cacheTable("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- AnalysisException: ...
- """
+ """Caches the specified table in-memory."""
self._jcatalog.cacheTable(tableName)
@since(2.0)
def uncacheTable(self, tableName):
- """Removes the specified table from the in-memory cache.
-
- >>> spark.catalog._reset()
- >>> spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("my_tab")
- >>> spark.catalog.cacheTable("my_tab")
- >>> spark.catalog.isCached("my_tab")
- True
- >>> spark.catalog.uncacheTable("my_tab")
- >>> spark.catalog.isCached("my_tab")
- False
- >>> spark.catalog.uncacheTable("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- AnalysisException: ...
- """
+ """Removes the specified table from the in-memory cache."""
self._jcatalog.uncacheTable(tableName)
@since(2.0)
def clearCache(self):
- """Removes all cached tables from the in-memory cache.
-
- >>> spark.catalog._reset()
- >>> spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("my_tab1")
- >>> spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("my_tab2")
- >>> spark.catalog.cacheTable("my_tab1")
- >>> spark.catalog.cacheTable("my_tab2")
- >>> spark.catalog.isCached("my_tab1")
- True
- >>> spark.catalog.isCached("my_tab2")
- True
- >>> spark.catalog.clearCache()
- >>> spark.catalog.isCached("my_tab1")
- False
- >>> spark.catalog.isCached("my_tab2")
- False
- """
+ """Removes all cached tables from the in-memory cache."""
self._jcatalog.clearCache()
def _reset(self):
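The caching doctests deleted above reappear as unit tests in tests.py further down. For quick reference, a hedged sketch of the surviving cache calls, assuming a live SparkSession `spark` (the table name is illustrative):

spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("my_tab")
assert not spark.catalog.isCached("my_tab")
spark.catalog.cacheTable("my_tab")       # pin the table in memory
assert spark.catalog.isCached("my_tab")
spark.catalog.uncacheTable("my_tab")     # evict just this table
spark.catalog.clearCache()               # or evict everything at once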
diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py
index 1d9f052e25..7428c91991 100644
--- a/python/pyspark/sql/conf.py
+++ b/python/pyspark/sql/conf.py
@@ -33,64 +33,34 @@ class RuntimeConfig(object):
@ignore_unicode_prefix
@since(2.0)
def set(self, key, value):
- """Sets the given Spark runtime configuration property.
-
- >>> spark.conf.set("garble", "marble")
- >>> spark.getConf("garble")
- u'marble'
- """
+ """Sets the given Spark runtime configuration property."""
self._jconf.set(key, value)
@ignore_unicode_prefix
@since(2.0)
- def get(self, key):
+ def get(self, key, default=None):
"""Returns the value of Spark runtime configuration property for the given key,
assuming it is set.
-
- >>> spark.setConf("bogo", "sipeo")
- >>> spark.conf.get("bogo")
- u'sipeo'
- >>> spark.conf.get("definitely.not.set") # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- Py4JJavaError: ...
- """
- return self._jconf.get(key)
-
- @ignore_unicode_prefix
- @since(2.0)
- def getOption(self, key):
- """Returns the value of Spark runtime configuration property for the given key,
- or None if it is not set.
-
- >>> spark.setConf("bogo", "sipeo")
- >>> spark.conf.getOption("bogo")
- u'sipeo'
- >>> spark.conf.getOption("definitely.not.set") is None
- True
"""
- iter = self._jconf.getOption(key).iterator()
- if iter.hasNext():
- return iter.next()
+ self._checkType(key, "key")
+ if default is None:
+ return self._jconf.get(key)
else:
- return None
+ self._checkType(default, "default")
+ return self._jconf.get(key, default)
@ignore_unicode_prefix
@since(2.0)
def unset(self, key):
- """Resets the configuration property for the given key.
-
- >>> spark.setConf("armado", "larmado")
- >>> spark.getConf("armado")
- u'larmado'
- >>> spark.conf.unset("armado")
- >>> spark.getConf("armado") # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- Py4JJavaError: ...
- """
+ """Resets the configuration property for the given key."""
self._jconf.unset(key)
+ def _checkType(self, obj, identifier):
+ """Assert that an object is of type str."""
+ if not isinstance(obj, str) and not isinstance(obj, unicode):
+ raise TypeError("expected %s '%s' to be a string (was '%s')" %
+ (identifier, obj, type(obj).__name__))
+
def _test():
import os
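With getOption() gone, a missing key either raises (when no default is given) or yields the caller's default, and non-string arguments are rejected by _checkType. A hedged sketch, assuming a live SparkSession `spark` (keys and values are illustrative):

from py4j.protocol import Py4JJavaError

spark.conf.set("bogo", "sipeo")
assert spark.conf.get("bogo") == u"sipeo"
assert spark.conf.get("not.set", "ta") == "ta"   # the default wins when the key is unset
try:
    spark.conf.get("not.set")                    # no default: the JVM side raises
except Py4JJavaError:
    pass
try:
    spark.conf.get("bogo", 42)                   # non-string default fails _checkType
except TypeError:
    pass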
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 94856c245b..417d719c35 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -127,10 +127,10 @@ class SQLContext(object):
>>> sqlContext.getConf("spark.sql.shuffle.partitions")
u'200'
- >>> sqlContext.getConf("spark.sql.shuffle.partitions", "10")
+ >>> sqlContext.getConf("spark.sql.shuffle.partitions", u"10")
u'10'
- >>> sqlContext.setConf("spark.sql.shuffle.partitions", "50")
- >>> sqlContext.getConf("spark.sql.shuffle.partitions", "10")
+ >>> sqlContext.setConf("spark.sql.shuffle.partitions", u"50")
+ >>> sqlContext.getConf("spark.sql.shuffle.partitions", u"10")
u'50'
"""
return self.sparkSession.getConf(key, defaultValue)
@@ -301,7 +301,7 @@ class SQLContext(object):
>>> sqlContext.registerDataFrameAsTable(df, "table1")
"""
- self.sparkSession.catalog.registerDataFrameAsTable(df, tableName)
+ self.sparkSession.catalog.registerTable(df, tableName)
@since(1.6)
def dropTempTable(self, tableName):
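SQLContext keeps its old method names for compatibility but now forwards to the session and its catalog. A hedged sketch, assuming a live `sqlContext` (values are illustrative):

sqlContext.setConf("spark.sql.shuffle.partitions", u"50")
sqlContext.getConf("spark.sql.shuffle.partitions", u"10")   # u'50': the set value wins over the default
df = sqlContext.createDataFrame([(1, 2)])
sqlContext.registerDataFrameAsTable(df, "table1")           # forwards to catalog.registerTable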
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index b3bc8961b8..c2452613ba 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -443,7 +443,7 @@ class SparkSession(object):
:return: :class:`DataFrame`
- >>> spark.catalog.registerDataFrameAsTable(df, "table1")
+ >>> spark.catalog.registerTable(df, "table1")
>>> df2 = spark.sql("SELECT field1 AS f1, field2 as f2 from table1")
>>> df2.collect()
[Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
@@ -456,7 +456,7 @@ class SparkSession(object):
:return: :class:`DataFrame`
- >>> spark.catalog.registerDataFrameAsTable(df, "table1")
+ >>> spark.catalog.registerTable(df, "table1")
>>> df2 = spark.table("table1")
>>> sorted(df.collect()) == sorted(df2.collect())
True
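The two doctests above exercise the renamed entry point end to end; a hedged, self-contained variant, assuming a SparkSession `spark` (schema and values are illustrative):

df = spark.createDataFrame([(1, u"row1"), (2, u"row2")], ["field1", "field2"])
spark.catalog.registerTable(df, "table1")
spark.sql("SELECT field1 AS f1 FROM table1").collect()   # [Row(f1=1), Row(f1=2)]
spark.table("table1").collect() == df.collect()          # True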
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1d3dc159da..ea98206836 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -45,7 +45,7 @@ if sys.version_info[:2] <= (2, 6):
else:
import unittest
-from pyspark.sql import SQLContext, HiveContext, Column, Row
+from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
from pyspark.sql.types import *
from pyspark.sql.types import UserDefinedType, _infer_type
from pyspark.tests import ReusedPySparkTestCase
@@ -199,7 +199,8 @@ class SQLTests(ReusedPySparkTestCase):
ReusedPySparkTestCase.setUpClass()
cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
os.unlink(cls.tempdir.name)
- cls.sqlCtx = SQLContext(cls.sc)
+ cls.sparkSession = SparkSession(cls.sc)
+ cls.sqlCtx = cls.sparkSession._wrapped
cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
rdd = cls.sc.parallelize(cls.testData, 2)
cls.df = rdd.toDF()
@@ -1394,6 +1395,200 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual(df.schema.simpleString(), "struct<value:int>")
self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
+ def test_conf(self):
+ spark = self.sparkSession
+ spark.setConf("bogo", "sipeo")
+ self.assertEqual(self.sparkSession.conf.get("bogo"), "sipeo")
+ spark.setConf("bogo", "ta")
+ self.assertEqual(spark.conf.get("bogo"), "ta")
+ self.assertEqual(spark.conf.get("bogo", "not.read"), "ta")
+ self.assertEqual(spark.conf.get("not.set", "ta"), "ta")
+ self.assertRaisesRegexp(Exception, "not.set", lambda: spark.conf.get("not.set"))
+ spark.conf.unset("bogo")
+ self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia")
+
+ def test_current_database(self):
+ spark = self.sparkSession
+ spark.catalog._reset()
+ self.assertEquals(spark.catalog.currentDatabase(), "default")
+ spark.sql("CREATE DATABASE some_db")
+ spark.catalog.setCurrentDatabase("some_db")
+ self.assertEquals(spark.catalog.currentDatabase(), "some_db")
+ self.assertRaisesRegexp(
+ AnalysisException,
+ "does_not_exist",
+ lambda: spark.catalog.setCurrentDatabase("does_not_exist"))
+
+ def test_list_databases(self):
+ spark = self.sparkSession
+ spark.catalog._reset()
+ databases = [db.name for db in spark.catalog.listDatabases()]
+ self.assertEquals(databases, ["default"])
+ spark.sql("CREATE DATABASE some_db")
+ databases = [db.name for db in spark.catalog.listDatabases()]
+ self.assertEquals(sorted(databases), ["default", "some_db"])
+
+ def test_list_tables(self):
+ from pyspark.sql.catalog import Table
+ spark = self.sparkSession
+ spark.catalog._reset()
+ spark.sql("CREATE DATABASE some_db")
+ self.assertEquals(spark.catalog.listTables(), [])
+ self.assertEquals(spark.catalog.listTables("some_db"), [])
+ spark.createDataFrame([(1, 1)]).registerTempTable("temp_tab")
+ spark.sql("CREATE TABLE tab1 (name STRING, age INT)")
+ spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT)")
+ tables = sorted(spark.catalog.listTables(), key=lambda t: t.name)
+ tablesDefault = sorted(spark.catalog.listTables("default"), key=lambda t: t.name)
+ tablesSomeDb = sorted(spark.catalog.listTables("some_db"), key=lambda t: t.name)
+ self.assertEquals(tables, tablesDefault)
+ self.assertEquals(len(tables), 2)
+ self.assertEquals(len(tablesSomeDb), 2)
+ self.assertEquals(tables[0], Table(
+ name="tab1",
+ database="default",
+ description=None,
+ tableType="MANAGED",
+ isTemporary=False))
+ self.assertEquals(tables[1], Table(
+ name="temp_tab",
+ database=None,
+ description=None,
+ tableType="TEMPORARY",
+ isTemporary=True))
+ self.assertEquals(tablesSomeDb[0], Table(
+ name="tab2",
+ database="some_db",
+ description=None,
+ tableType="MANAGED",
+ isTemporary=False))
+ self.assertEquals(tablesSomeDb[1], Table(
+ name="temp_tab",
+ database=None,
+ description=None,
+ tableType="TEMPORARY",
+ isTemporary=True))
+ self.assertRaisesRegexp(
+ AnalysisException,
+ "does_not_exist",
+ lambda: spark.catalog.listTables("does_not_exist"))
+
+ def test_list_functions(self):
+ from pyspark.sql.catalog import Function
+ spark = self.sparkSession
+ spark.catalog._reset()
+ spark.sql("CREATE DATABASE some_db")
+ functions = dict((f.name, f) for f in spark.catalog.listFunctions())
+ functionsDefault = dict((f.name, f) for f in spark.catalog.listFunctions("default"))
+ self.assertTrue(len(functions) > 200)
+ self.assertTrue("+" in functions)
+ self.assertTrue("like" in functions)
+ self.assertTrue("month" in functions)
+ self.assertTrue("to_unix_timestamp" in functions)
+ self.assertTrue("current_database" in functions)
+ self.assertEquals(functions["+"], Function(
+ name="+",
+ description=None,
+ className="org.apache.spark.sql.catalyst.expressions.Add",
+ isTemporary=True))
+ self.assertEquals(functions, functionsDefault)
+ spark.catalog.registerFunction("temp_func", lambda x: str(x))
+ spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'")
+ spark.sql("CREATE FUNCTION some_db.func2 AS 'org.apache.spark.data.bricks'")
+ newFunctions = dict((f.name, f) for f in spark.catalog.listFunctions())
+ newFunctionsSomeDb = dict((f.name, f) for f in spark.catalog.listFunctions("some_db"))
+ self.assertTrue(set(functions).issubset(set(newFunctions)))
+ self.assertTrue(set(functions).issubset(set(newFunctionsSomeDb)))
+ self.assertTrue("temp_func" in newFunctions)
+ self.assertTrue("func1" in newFunctions)
+ self.assertTrue("func2" not in newFunctions)
+ self.assertTrue("temp_func" in newFunctionsSomeDb)
+ self.assertTrue("func1" not in newFunctionsSomeDb)
+ self.assertTrue("func2" in newFunctionsSomeDb)
+ self.assertRaisesRegexp(
+ AnalysisException,
+ "does_not_exist",
+ lambda: spark.catalog.listFunctions("does_not_exist"))
+
+ def test_list_columns(self):
+ from pyspark.sql.catalog import Column
+ spark = self.sparkSession
+ spark.catalog._reset()
+ spark.sql("CREATE DATABASE some_db")
+ spark.sql("CREATE TABLE tab1 (name STRING, age INT)")
+ spark.sql("CREATE TABLE some_db.tab2 (nickname STRING, tolerance FLOAT)")
+ columns = sorted(spark.catalog.listColumns("tab1"), key=lambda c: c.name)
+ columnsDefault = sorted(spark.catalog.listColumns("tab1", "default"), key=lambda c: c.name)
+ self.assertEquals(columns, columnsDefault)
+ self.assertEquals(len(columns), 2)
+ self.assertEquals(columns[0], Column(
+ name="age",
+ description=None,
+ dataType="int",
+ nullable=True,
+ isPartition=False,
+ isBucket=False))
+ self.assertEquals(columns[1], Column(
+ name="name",
+ description=None,
+ dataType="string",
+ nullable=True,
+ isPartition=False,
+ isBucket=False))
+ columns2 = sorted(spark.catalog.listColumns("tab2", "some_db"), key=lambda c: c.name)
+ self.assertEquals(len(columns2), 2)
+ self.assertEquals(columns2[0], Column(
+ name="nickname",
+ description=None,
+ dataType="string",
+ nullable=True,
+ isPartition=False,
+ isBucket=False))
+ self.assertEquals(columns2[1], Column(
+ name="tolerance",
+ description=None,
+ dataType="float",
+ nullable=True,
+ isPartition=False,
+ isBucket=False))
+ self.assertRaisesRegexp(
+ AnalysisException,
+ "tab2",
+ lambda: spark.catalog.listColumns("tab2"))
+ self.assertRaisesRegexp(
+ AnalysisException,
+ "does_not_exist",
+ lambda: spark.catalog.listColumns("does_not_exist"))
+
+ def test_cache(self):
+ spark = self.sparkSession
+ spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("tab1")
+ spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("tab2")
+ self.assertFalse(spark.catalog.isCached("tab1"))
+ self.assertFalse(spark.catalog.isCached("tab2"))
+ spark.catalog.cacheTable("tab1")
+ self.assertTrue(spark.catalog.isCached("tab1"))
+ self.assertFalse(spark.catalog.isCached("tab2"))
+ spark.catalog.cacheTable("tab2")
+ spark.catalog.uncacheTable("tab1")
+ self.assertFalse(spark.catalog.isCached("tab1"))
+ self.assertTrue(spark.catalog.isCached("tab2"))
+ spark.catalog.clearCache()
+ self.assertFalse(spark.catalog.isCached("tab1"))
+ self.assertFalse(spark.catalog.isCached("tab2"))
+ self.assertRaisesRegexp(
+ AnalysisException,
+ "does_not_exist",
+ lambda: spark.catalog.isCached("does_not_exist"))
+ self.assertRaisesRegexp(
+ AnalysisException,
+ "does_not_exist",
+ lambda: spark.catalog.cacheTable("does_not_exist"))
+ self.assertRaisesRegexp(
+ AnalysisException,
+ "does_not_exist",
+ lambda: spark.catalog.uncacheTable("does_not_exist"))
+
class HiveContextSQLTests(ReusedPySparkTestCase):
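The new tests replace the deleted doctests one for one. As a hedged illustration of the pattern, a minimal hypothetical standalone file exercising the new conf API (class name, app name, and keys are placeholders, not part of the patch):

import unittest

from pyspark import SparkContext
from pyspark.sql import SparkSession


class ConfTests(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        # One local SparkContext/SparkSession pair shared by the whole class.
        cls.sc = SparkContext("local[2]", "conf-tests")
        cls.spark = SparkSession(cls.sc)

    @classmethod
    def tearDownClass(cls):
        cls.sc.stop()

    def test_get_with_default(self):
        self.spark.conf.set("bogo", "sipeo")
        # A set key ignores the default; an unset key falls back to it.
        self.assertEqual(self.spark.conf.get("bogo", "ignored"), "sipeo")
        self.assertEqual(self.spark.conf.get("not.set", "ta"), "ta")


if __name__ == "__main__":
    unittest.main()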