aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/sql
diff options
context:
space:
mode:
authorAndrew Or <andrew@databricks.com>2016-04-29 09:34:10 -0700
committerAndrew Or <andrew@databricks.com>2016-04-29 09:34:10 -0700
commita7d0fedc940721d09350f2e57ae85591e0a3d90e (patch)
tree5f96e980e810cd13f36658ed052a1e987c5d261c /python/pyspark/sql
parent7feeb82cb7f462e44f7e698c7c3b6ac3a77aade4 (diff)
downloadspark-a7d0fedc940721d09350f2e57ae85591e0a3d90e.tar.gz
spark-a7d0fedc940721d09350f2e57ae85591e0a3d90e.tar.bz2
spark-a7d0fedc940721d09350f2e57ae85591e0a3d90e.zip
[SPARK-14988][PYTHON] SparkSession catalog and conf API
## What changes were proposed in this pull request? The `catalog` and `conf` APIs were exposed in `SparkSession` in #12713 and #12669. This patch adds those to the python API. ## How was this patch tested? Python tests. Author: Andrew Or <andrew@databricks.com> Closes #12765 from andrewor14/python-spark-session-more.
Diffstat (limited to 'python/pyspark/sql')
-rw-r--r--python/pyspark/sql/catalog.py426
-rw-r--r--python/pyspark/sql/conf.py114
-rw-r--r--python/pyspark/sql/context.py11
-rw-r--r--python/pyspark/sql/session.py139
4 files changed, 605 insertions, 85 deletions
diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
new file mode 100644
index 0000000000..4f9238374a
--- /dev/null
+++ b/python/pyspark/sql/catalog.py
@@ -0,0 +1,426 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from collections import namedtuple
+
+from pyspark import since
+from pyspark.rdd import ignore_unicode_prefix
+from pyspark.sql.dataframe import DataFrame
+from pyspark.sql.functions import UserDefinedFunction
+from pyspark.sql.types import IntegerType, StringType, StructType
+
+
+Database = namedtuple("Database", "name description locationUri")
+Table = namedtuple("Table", "name database description tableType isTemporary")
+Column = namedtuple("Column", "name description dataType nullable isPartition isBucket")
+Function = namedtuple("Function", "name description className isTemporary")
+
+
+class Catalog(object):
+ """User-facing catalog API, accessible through `SparkSession.catalog`.
+
+ This is a thin wrapper around its Scala implementation org.apache.spark.sql.catalog.Catalog.
+ """
+
+ def __init__(self, sparkSession):
+ """Create a new Catalog that wraps the underlying JVM object."""
+ self._sparkSession = sparkSession
+ self._jsparkSession = sparkSession._jsparkSession
+ self._jcatalog = sparkSession._jsparkSession.catalog()
+
+ @ignore_unicode_prefix
+ @since(2.0)
+ def currentDatabase(self):
+ """Returns the current default database in this session.
+
+ >>> spark.catalog._reset()
+ >>> spark.catalog.currentDatabase()
+ u'default'
+ """
+ return self._jcatalog.currentDatabase()
+
+ @ignore_unicode_prefix
+ @since(2.0)
+ def setCurrentDatabase(self, dbName):
+ """Sets the current default database in this session.
+
+ >>> spark.catalog._reset()
+ >>> spark.sql("CREATE DATABASE some_db")
+ DataFrame[]
+ >>> spark.catalog.setCurrentDatabase("some_db")
+ >>> spark.catalog.currentDatabase()
+ u'some_db'
+ >>> spark.catalog.setCurrentDatabase("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ AnalysisException: ...
+ """
+ return self._jcatalog.setCurrentDatabase(dbName)
+
+ @ignore_unicode_prefix
+ @since(2.0)
+ def listDatabases(self):
+ """Returns a list of databases available across all sessions.
+
+ >>> spark.catalog._reset()
+ >>> [db.name for db in spark.catalog.listDatabases()]
+ [u'default']
+ >>> spark.sql("CREATE DATABASE some_db")
+ DataFrame[]
+ >>> [db.name for db in spark.catalog.listDatabases()]
+ [u'default', u'some_db']
+ """
+ iter = self._jcatalog.listDatabases().toLocalIterator()
+ databases = []
+ while iter.hasNext():
+ jdb = iter.next()
+ databases.append(Database(
+ name=jdb.name(),
+ description=jdb.description(),
+ locationUri=jdb.locationUri()))
+ return databases
+
+ @ignore_unicode_prefix
+ @since(2.0)
+ def listTables(self, dbName=None):
+ """Returns a list of tables in the specified database.
+
+ If no database is specified, the current database is used.
+ This includes all temporary tables.
+
+ >>> spark.catalog._reset()
+ >>> spark.sql("CREATE DATABASE some_db")
+ DataFrame[]
+ >>> spark.catalog.listTables()
+ []
+ >>> spark.catalog.listTables("some_db")
+ []
+ >>> spark.createDataFrame([(1, 1)]).registerTempTable("my_temp_tab")
+ >>> spark.sql("CREATE TABLE my_tab1 (name STRING, age INT)")
+ DataFrame[]
+ >>> spark.sql("CREATE TABLE some_db.my_tab2 (name STRING, age INT)")
+ DataFrame[]
+ >>> spark.catalog.listTables()
+ [Table(name=u'my_tab1', database=u'default', description=None, tableType=u'MANAGED',
+ isTemporary=False), Table(name=u'my_temp_tab', database=None, description=None,
+ tableType=u'TEMPORARY', isTemporary=True)]
+ >>> spark.catalog.listTables("some_db")
+ [Table(name=u'my_tab2', database=u'some_db', description=None, tableType=u'MANAGED',
+ isTemporary=False), Table(name=u'my_temp_tab', database=None, description=None,
+ tableType=u'TEMPORARY', isTemporary=True)]
+ >>> spark.catalog.listTables("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ AnalysisException: ...
+ """
+ if dbName is None:
+ dbName = self.currentDatabase()
+ iter = self._jcatalog.listTables(dbName).toLocalIterator()
+ tables = []
+ while iter.hasNext():
+ jtable = iter.next()
+ tables.append(Table(
+ name=jtable.name(),
+ database=jtable.database(),
+ description=jtable.description(),
+ tableType=jtable.tableType(),
+ isTemporary=jtable.isTemporary()))
+ return tables
+
+ @ignore_unicode_prefix
+ @since(2.0)
+ def listFunctions(self, dbName=None):
+ """Returns a list of functions registered in the specified database.
+
+ If no database is specified, the current database is used.
+ This includes all temporary functions.
+
+ >>> spark.catalog._reset()
+ >>> spark.sql("CREATE DATABASE my_db")
+ DataFrame[]
+ >>> funcNames = set(f.name for f in spark.catalog.listFunctions())
+ >>> set(["+", "floor", "to_unix_timestamp", "current_database"]).issubset(funcNames)
+ True
+ >>> spark.sql("CREATE FUNCTION my_func1 AS 'org.apache.spark.whatever'")
+ DataFrame[]
+ >>> spark.sql("CREATE FUNCTION my_db.my_func2 AS 'org.apache.spark.whatever'")
+ DataFrame[]
+ >>> spark.catalog.registerFunction("temp_func", lambda x: str(x))
+ >>> newFuncNames = set(f.name for f in spark.catalog.listFunctions()) - funcNames
+ >>> newFuncNamesDb = set(f.name for f in spark.catalog.listFunctions("my_db")) - funcNames
+ >>> sorted(list(newFuncNames))
+ [u'my_func1', u'temp_func']
+ >>> sorted(list(newFuncNamesDb))
+ [u'my_func2', u'temp_func']
+ >>> spark.catalog.listFunctions("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ AnalysisException: ...
+ """
+ if dbName is None:
+ dbName = self.currentDatabase()
+ iter = self._jcatalog.listFunctions(dbName).toLocalIterator()
+ functions = []
+ while iter.hasNext():
+ jfunction = iter.next()
+ functions.append(Function(
+ name=jfunction.name(),
+ description=jfunction.description(),
+ className=jfunction.className(),
+ isTemporary=jfunction.isTemporary()))
+ return functions
+
+ @ignore_unicode_prefix
+ @since(2.0)
+ def listColumns(self, tableName, dbName=None):
+ """Returns a list of columns for the given table in the specified database.
+
+ If no database is specified, the current database is used.
+
+ Note: the order of arguments here is different from that of its JVM counterpart
+ because Python does not support method overloading.
+
+ >>> spark.catalog._reset()
+ >>> spark.sql("CREATE DATABASE some_db")
+ DataFrame[]
+ >>> spark.sql("CREATE TABLE my_tab1 (name STRING, age INT)")
+ DataFrame[]
+ >>> spark.sql("CREATE TABLE some_db.my_tab2 (nickname STRING, tolerance FLOAT)")
+ DataFrame[]
+ >>> spark.catalog.listColumns("my_tab1")
+ [Column(name=u'name', description=None, dataType=u'string', nullable=True,
+ isPartition=False, isBucket=False), Column(name=u'age', description=None,
+ dataType=u'int', nullable=True, isPartition=False, isBucket=False)]
+ >>> spark.catalog.listColumns("my_tab2", "some_db")
+ [Column(name=u'nickname', description=None, dataType=u'string', nullable=True,
+ isPartition=False, isBucket=False), Column(name=u'tolerance', description=None,
+ dataType=u'float', nullable=True, isPartition=False, isBucket=False)]
+ >>> spark.catalog.listColumns("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ AnalysisException: ...
+ """
+ if dbName is None:
+ dbName = self.currentDatabase()
+ iter = self._jcatalog.listColumns(dbName, tableName).toLocalIterator()
+ columns = []
+ while iter.hasNext():
+ jcolumn = iter.next()
+ columns.append(Column(
+ name=jcolumn.name(),
+ description=jcolumn.description(),
+ dataType=jcolumn.dataType(),
+ nullable=jcolumn.nullable(),
+ isPartition=jcolumn.isPartition(),
+ isBucket=jcolumn.isBucket()))
+ return columns
+
+ @since(2.0)
+ def createExternalTable(self, tableName, path=None, source=None, schema=None, **options):
+ """Creates an external table based on the dataset in a data source.
+
+ It returns the DataFrame associated with the external table.
+
+ The data source is specified by the ``source`` and a set of ``options``.
+ If ``source`` is not specified, the default data source configured by
+ ``spark.sql.sources.default`` will be used.
+
+ Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and
+ created external table.
+
+ :return: :class:`DataFrame`
+ """
+ if path is not None:
+ options["path"] = path
+ if source is None:
+ source = self._sparkSession.getConf(
+ "spark.sql.sources.default", "org.apache.spark.sql.parquet")
+ if schema is None:
+ df = self._jcatalog.createExternalTable(tableName, source, options)
+ else:
+ if not isinstance(schema, StructType):
+ raise TypeError("schema should be StructType")
+ scala_datatype = self._jsparkSession.parseDataType(schema.json())
+ df = self._jcatalog.createExternalTable(tableName, source, scala_datatype, options)
+ return DataFrame(df, self._sparkSession._wrapped)
+
+ @since(2.0)
+ def dropTempTable(self, tableName):
+ """Drops the temporary table with the given table name in the catalog.
+ If the table has been cached before, then it will also be uncached.
+
+ >>> spark.createDataFrame([(1, 1)]).registerTempTable("my_table")
+ >>> spark.table("my_table").collect()
+ [Row(_1=1, _2=1)]
+ >>> spark.catalog.dropTempTable("my_table")
+ >>> spark.table("my_table") # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ AnalysisException: ...
+ """
+ self._jcatalog.dropTempTable(tableName)
+
+ @since(2.0)
+ def registerDataFrameAsTable(self, df, tableName):
+ """Registers the given :class:`DataFrame` as a temporary table in the catalog.
+
+ >>> df = spark.createDataFrame([(2, 1), (3, 1)])
+ >>> spark.catalog.registerDataFrameAsTable(df, "my_cool_table")
+ >>> spark.table("my_cool_table").collect()
+ [Row(_1=2, _2=1), Row(_1=3, _2=1)]
+ """
+ if isinstance(df, DataFrame):
+ self._jsparkSession.registerDataFrameAsTable(df._jdf, tableName)
+ else:
+ raise ValueError("Can only register DataFrame as table")
+
+ @ignore_unicode_prefix
+ @since(2.0)
+ def registerFunction(self, name, f, returnType=StringType()):
+ """Registers a python function (including lambda function) as a UDF
+ so it can be used in SQL statements.
+
+ In addition to a name and the function itself, the return type can be optionally specified.
+ When the return type is not given it default to a string and conversion will automatically
+ be done. For any other return type, the produced object must match the specified type.
+
+ :param name: name of the UDF
+ :param f: python function
+ :param returnType: a :class:`DataType` object
+
+ >>> spark.catalog.registerFunction("stringLengthString", lambda x: len(x))
+ >>> spark.sql("SELECT stringLengthString('test')").collect()
+ [Row(stringLengthString(test)=u'4')]
+
+ >>> from pyspark.sql.types import IntegerType
+ >>> spark.catalog.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
+ >>> spark.sql("SELECT stringLengthInt('test')").collect()
+ [Row(stringLengthInt(test)=4)]
+
+ >>> from pyspark.sql.types import IntegerType
+ >>> spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
+ >>> spark.sql("SELECT stringLengthInt('test')").collect()
+ [Row(stringLengthInt(test)=4)]
+ """
+ udf = UserDefinedFunction(f, returnType, name)
+ self._jsparkSession.udf().registerPython(name, udf._judf)
+
+ @since(2.0)
+ def isCached(self, tableName):
+ """Returns true if the table is currently cached in-memory.
+
+ >>> spark.catalog._reset()
+ >>> spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("my_tab")
+ >>> spark.catalog.isCached("my_tab")
+ False
+ >>> spark.catalog.cacheTable("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ AnalysisException: ...
+ """
+ return self._jcatalog.isCached(tableName)
+
+ @since(2.0)
+ def cacheTable(self, tableName):
+ """Caches the specified table in-memory.
+
+ >>> spark.catalog._reset()
+ >>> spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("my_tab")
+ >>> spark.catalog.isCached("my_tab")
+ False
+ >>> spark.catalog.cacheTable("my_tab")
+ >>> spark.catalog.isCached("my_tab")
+ True
+ >>> spark.catalog.cacheTable("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ AnalysisException: ...
+ """
+ self._jcatalog.cacheTable(tableName)
+
+ @since(2.0)
+ def uncacheTable(self, tableName):
+ """Removes the specified table from the in-memory cache.
+
+ >>> spark.catalog._reset()
+ >>> spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("my_tab")
+ >>> spark.catalog.cacheTable("my_tab")
+ >>> spark.catalog.isCached("my_tab")
+ True
+ >>> spark.catalog.uncacheTable("my_tab")
+ >>> spark.catalog.isCached("my_tab")
+ False
+ >>> spark.catalog.uncacheTable("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ AnalysisException: ...
+ """
+ self._jcatalog.uncacheTable(tableName)
+
+ @since(2.0)
+ def clearCache(self):
+ """Removes all cached tables from the in-memory cache.
+
+ >>> spark.catalog._reset()
+ >>> spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("my_tab1")
+ >>> spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("my_tab2")
+ >>> spark.catalog.cacheTable("my_tab1")
+ >>> spark.catalog.cacheTable("my_tab2")
+ >>> spark.catalog.isCached("my_tab1")
+ True
+ >>> spark.catalog.isCached("my_tab2")
+ True
+ >>> spark.catalog.clearCache()
+ >>> spark.catalog.isCached("my_tab1")
+ False
+ >>> spark.catalog.isCached("my_tab2")
+ False
+ """
+ self._jcatalog.clearCache()
+
+ def _reset(self):
+ """(Internal use only) Drop all existing databases (except "default"), tables,
+ partitions and functions, and set the current database to "default".
+
+ This is mainly used for tests.
+ """
+ self._jsparkSession.sessionState().catalog().reset()
+
+
+def _test():
+ import os
+ import doctest
+ from pyspark.context import SparkContext
+ from pyspark.sql.session import SparkSession
+ import pyspark.sql.catalog
+
+ os.chdir(os.environ["SPARK_HOME"])
+
+ globs = pyspark.sql.catalog.__dict__.copy()
+ sc = SparkContext('local[4]', 'PythonTest')
+ globs['sc'] = sc
+ globs['spark'] = SparkSession(sc)
+ (failure_count, test_count) = doctest.testmod(
+ pyspark.sql.catalog,
+ globs=globs,
+ optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+if __name__ == "__main__":
+ _test()
diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py
new file mode 100644
index 0000000000..1d9f052e25
--- /dev/null
+++ b/python/pyspark/sql/conf.py
@@ -0,0 +1,114 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark import since
+from pyspark.rdd import ignore_unicode_prefix
+
+
+class RuntimeConfig(object):
+ """User-facing configuration API, accessible through `SparkSession.conf`.
+
+ Options set here are automatically propagated to the Hadoop configuration during I/O.
+ This a thin wrapper around its Scala implementation org.apache.spark.sql.RuntimeConfig.
+ """
+
+ def __init__(self, jconf):
+ """Create a new RuntimeConfig that wraps the underlying JVM object."""
+ self._jconf = jconf
+
+ @ignore_unicode_prefix
+ @since(2.0)
+ def set(self, key, value):
+ """Sets the given Spark runtime configuration property.
+
+ >>> spark.conf.set("garble", "marble")
+ >>> spark.getConf("garble")
+ u'marble'
+ """
+ self._jconf.set(key, value)
+
+ @ignore_unicode_prefix
+ @since(2.0)
+ def get(self, key):
+ """Returns the value of Spark runtime configuration property for the given key,
+ assuming it is set.
+
+ >>> spark.setConf("bogo", "sipeo")
+ >>> spark.conf.get("bogo")
+ u'sipeo'
+ >>> spark.conf.get("definitely.not.set") # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ Py4JJavaError: ...
+ """
+ return self._jconf.get(key)
+
+ @ignore_unicode_prefix
+ @since(2.0)
+ def getOption(self, key):
+ """Returns the value of Spark runtime configuration property for the given key,
+ or None if it is not set.
+
+ >>> spark.setConf("bogo", "sipeo")
+ >>> spark.conf.getOption("bogo")
+ u'sipeo'
+ >>> spark.conf.getOption("definitely.not.set") is None
+ True
+ """
+ iter = self._jconf.getOption(key).iterator()
+ if iter.hasNext():
+ return iter.next()
+ else:
+ return None
+
+ @ignore_unicode_prefix
+ @since(2.0)
+ def unset(self, key):
+ """Resets the configuration property for the given key.
+
+ >>> spark.setConf("armado", "larmado")
+ >>> spark.getConf("armado")
+ u'larmado'
+ >>> spark.conf.unset("armado")
+ >>> spark.getConf("armado") # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ Py4JJavaError: ...
+ """
+ self._jconf.unset(key)
+
+
+def _test():
+ import os
+ import doctest
+ from pyspark.context import SparkContext
+ from pyspark.sql.session import SparkSession
+ import pyspark.sql.conf
+
+ os.chdir(os.environ["SPARK_HOME"])
+
+ globs = pyspark.sql.conf.__dict__.copy()
+ sc = SparkContext('local[4]', 'PythonTest')
+ globs['sc'] = sc
+ globs['spark'] = SparkSession(sc)
+ (failure_count, test_count) = doctest.testmod(pyspark.sql.conf, globs=globs)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+if __name__ == "__main__":
+ _test()
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index a3ea192b28..94856c245b 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -142,7 +142,7 @@ class SQLContext(object):
:return: :class:`UDFRegistration`
"""
- return UDFRegistration(self.sparkSession)
+ return UDFRegistration(self)
@since(1.4)
def range(self, start, end=None, step=1, numPartitions=None):
@@ -195,7 +195,7 @@ class SQLContext(object):
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(stringLengthInt(test)=4)]
"""
- self.sparkSession.registerFunction(name, f, returnType)
+ self.sparkSession.catalog.registerFunction(name, f, returnType)
# TODO(andrew): delete this once we refactor things to take in SparkSession
def _inferSchema(self, rdd, samplingRatio=None):
@@ -301,7 +301,7 @@ class SQLContext(object):
>>> sqlContext.registerDataFrameAsTable(df, "table1")
"""
- self.sparkSession.registerDataFrameAsTable(df, tableName)
+ self.sparkSession.catalog.registerDataFrameAsTable(df, tableName)
@since(1.6)
def dropTempTable(self, tableName):
@@ -310,7 +310,7 @@ class SQLContext(object):
>>> sqlContext.registerDataFrameAsTable(df, "table1")
>>> sqlContext.dropTempTable("table1")
"""
- self._ssql_ctx.dropTempTable(tableName)
+ self.sparkSession.catalog.dropTempTable(tableName)
@since(1.3)
def createExternalTable(self, tableName, path=None, source=None, schema=None, **options):
@@ -327,7 +327,8 @@ class SQLContext(object):
:return: :class:`DataFrame`
"""
- return self.sparkSession.createExternalTable(tableName, path, source, schema, **options)
+ return self.sparkSession.catalog.createExternalTable(
+ tableName, path, source, schema, **options)
@ignore_unicode_prefix
@since(1.0)
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index d3355f9da7..b3bc8961b8 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -27,8 +27,9 @@ else:
from pyspark import since
from pyspark.rdd import RDD, ignore_unicode_prefix
+from pyspark.sql.catalog import Catalog
+from pyspark.sql.conf import RuntimeConfig
from pyspark.sql.dataframe import DataFrame
-from pyspark.sql.functions import UserDefinedFunction
from pyspark.sql.readwriter import DataFrameReader
from pyspark.sql.types import Row, DataType, StringType, StructType, _verify_type, \
_infer_schema, _has_nulltype, _merge_type, _create_converter, _parse_datatype_string
@@ -56,7 +57,6 @@ def _monkey_patch_RDD(sparkSession):
RDD.toDF = toDF
-# TODO(andrew): implement conf and catalog namespaces
class SparkSession(object):
"""Main entry point for Spark SQL functionality.
@@ -105,7 +105,7 @@ class SparkSession(object):
@classmethod
@since(2.0)
def withHiveSupport(cls, sparkContext):
- """Returns a new SparkSession with a catalog backed by Hive
+ """Returns a new SparkSession with a catalog backed by Hive.
:param sparkContext: The underlying :class:`SparkContext`.
"""
@@ -121,6 +121,19 @@ class SparkSession(object):
"""
return self.__class__(self._sc, self._jsparkSession.newSession())
+ @property
+ @since(2.0)
+ def conf(self):
+ """Runtime configuration interface for Spark.
+
+ This is the interface through which the user can get and set all Spark and Hadoop
+ configurations that are relevant to Spark SQL. When getting the value of a config,
+ this defaults to the value set in the underlying :class:`SparkContext`, if any.
+ """
+ if not hasattr(self, "_conf"):
+ self._conf = RuntimeConfig(self._jsparkSession.conf())
+ return self._conf
+
@since(2.0)
def setConf(self, key, value):
"""
@@ -152,12 +165,23 @@ class SparkSession(object):
@property
@since(2.0)
+ def catalog(self):
+ """Interface through which the user may create, drop, alter or query underlying
+ databases, tables, functions etc.
+ """
+ if not hasattr(self, "_catalog"):
+ self._catalog = Catalog(self)
+ return self._catalog
+
+ @property
+ @since(2.0)
def udf(self):
"""Returns a :class:`UDFRegistration` for UDF registration.
:return: :class:`UDFRegistration`
"""
- return UDFRegistration(self)
+ from pyspark.sql.context import UDFRegistration
+ return UDFRegistration(self._wrapped)
@since(2.0)
def range(self, start, end=None, step=1, numPartitions=None):
@@ -190,37 +214,6 @@ class SparkSession(object):
return DataFrame(jdf, self._wrapped)
- @ignore_unicode_prefix
- @since(2.0)
- def registerFunction(self, name, f, returnType=StringType()):
- """Registers a python function (including lambda function) as a UDF
- so it can be used in SQL statements.
-
- In addition to a name and the function itself, the return type can be optionally specified.
- When the return type is not given it default to a string and conversion will automatically
- be done. For any other return type, the produced object must match the specified type.
-
- :param name: name of the UDF
- :param f: python function
- :param returnType: a :class:`DataType` object
-
- >>> spark.registerFunction("stringLengthString", lambda x: len(x))
- >>> spark.sql("SELECT stringLengthString('test')").collect()
- [Row(stringLengthString(test)=u'4')]
-
- >>> from pyspark.sql.types import IntegerType
- >>> spark.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
- >>> spark.sql("SELECT stringLengthInt('test')").collect()
- [Row(stringLengthInt(test)=4)]
-
- >>> from pyspark.sql.types import IntegerType
- >>> spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
- >>> spark.sql("SELECT stringLengthInt('test')").collect()
- [Row(stringLengthInt(test)=4)]
- """
- udf = UserDefinedFunction(f, returnType, name)
- self._jsparkSession.udf().registerPython(name, udf._judf)
-
def _inferSchemaFromList(self, data):
"""
Infer schema from list of Row or tuple.
@@ -443,49 +436,6 @@ class SparkSession(object):
df._schema = schema
return df
- @since(2.0)
- def registerDataFrameAsTable(self, df, tableName):
- """Registers the given :class:`DataFrame` as a temporary table in the catalog.
-
- Temporary tables exist only during the lifetime of this instance of :class:`SparkSession`.
-
- >>> spark.registerDataFrameAsTable(df, "table1")
- """
- if (df.__class__ is DataFrame):
- self._jsparkSession.registerDataFrameAsTable(df._jdf, tableName)
- else:
- raise ValueError("Can only register DataFrame as table")
-
- @since(2.0)
- def createExternalTable(self, tableName, path=None, source=None, schema=None, **options):
- """Creates an external table based on the dataset in a data source.
-
- It returns the DataFrame associated with the external table.
-
- The data source is specified by the ``source`` and a set of ``options``.
- If ``source`` is not specified, the default data source configured by
- ``spark.sql.sources.default`` will be used.
-
- Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and
- created external table.
-
- :return: :class:`DataFrame`
- """
- if path is not None:
- options["path"] = path
- if source is None:
- source = self.getConf("spark.sql.sources.default",
- "org.apache.spark.sql.parquet")
- if schema is None:
- df = self._jsparkSession.catalog().createExternalTable(tableName, source, options)
- else:
- if not isinstance(schema, StructType):
- raise TypeError("schema should be StructType")
- scala_datatype = self._jsparkSession.parseDataType(schema.json())
- df = self._jsparkSession.catalog().createExternalTable(
- tableName, source, scala_datatype, options)
- return DataFrame(df, self._wrapped)
-
@ignore_unicode_prefix
@since(2.0)
def sql(self, sqlQuery):
@@ -493,7 +443,7 @@ class SparkSession(object):
:return: :class:`DataFrame`
- >>> spark.registerDataFrameAsTable(df, "table1")
+ >>> spark.catalog.registerDataFrameAsTable(df, "table1")
>>> df2 = spark.sql("SELECT field1 AS f1, field2 as f2 from table1")
>>> df2.collect()
[Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
@@ -506,7 +456,7 @@ class SparkSession(object):
:return: :class:`DataFrame`
- >>> spark.registerDataFrameAsTable(df, "table1")
+ >>> spark.catalog.registerDataFrameAsTable(df, "table1")
>>> df2 = spark.table("table1")
>>> sorted(df.collect()) == sorted(df2.collect())
True
@@ -523,3 +473,32 @@ class SparkSession(object):
:return: :class:`DataFrameReader`
"""
return DataFrameReader(self._wrapped)
+
+
+def _test():
+ import os
+ import doctest
+ from pyspark.context import SparkContext
+ from pyspark.sql import Row
+ import pyspark.sql.session
+
+ os.chdir(os.environ["SPARK_HOME"])
+
+ globs = pyspark.sql.session.__dict__.copy()
+ sc = SparkContext('local[4]', 'PythonTest')
+ globs['sc'] = sc
+ globs['spark'] = SparkSession(sc)
+ globs['rdd'] = rdd = sc.parallelize(
+ [Row(field1=1, field2="row1"),
+ Row(field1=2, field2="row2"),
+ Row(field1=3, field2="row3")])
+ globs['df'] = rdd.toDF()
+ (failure_count, test_count) = doctest.testmod(
+ pyspark.sql.session, globs=globs,
+ optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+if __name__ == "__main__":
+ _test()