From a7d0fedc940721d09350f2e57ae85591e0a3d90e Mon Sep 17 00:00:00 2001
From: Andrew Or
Date: Fri, 29 Apr 2016 09:34:10 -0700
Subject: [SPARK-14988][PYTHON] SparkSession catalog and conf API

## What changes were proposed in this pull request?

The `catalog` and `conf` APIs were exposed in `SparkSession` in #12713 and #12669. This patch adds those to the python API.

## How was this patch tested?

Python tests.
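
A minimal usage sketch of the new `conf` and `catalog` namespaces (not part of the diff; it only uses calls added below and mirrors the doctests and `_test()` setup):

```python
from pyspark.context import SparkContext
from pyspark.sql.session import SparkSession

sc = SparkContext('local[4]', 'example')
spark = SparkSession(sc)

# Runtime configuration, a thin wrapper around the JVM-side RuntimeConfig.
spark.conf.set("spark.sql.shuffle.partitions", "10")
print(spark.conf.get("spark.sql.shuffle.partitions"))      # u'10'

# Catalog introspection, a thin wrapper around the JVM-side Catalog.
print([db.name for db in spark.catalog.listDatabases()])   # [u'default']
spark.createDataFrame([(1, 1)]).registerTempTable("my_temp_tab")
print([t.name for t in spark.catalog.listTables()])        # includes u'my_temp_tab'

sc.stop()
```
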
Author: Andrew Or

Closes #12765 from andrewor14/python-spark-session-more.
---
 dev/sparktestsupport/modules.py                    |   3 +
 python/pyspark/sql/catalog.py                      | 426 +++++++++++++++++++++
 python/pyspark/sql/conf.py                         | 114 ++++++
 python/pyspark/sql/context.py                      |  11 +-
 python/pyspark/sql/session.py                      | 139 +++----
 .../org/apache/spark/sql/catalog/Catalog.scala     |   4 +-
 .../org/apache/spark/sql/catalog/interface.scala   |   1 +
 7 files changed, 611 insertions(+), 87 deletions(-)
 create mode 100644 python/pyspark/sql/catalog.py
 create mode 100644 python/pyspark/sql/conf.py

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 6d47733ec1..5640928643 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -334,6 +334,9 @@ pyspark_sql = Module(
     python_test_goals=[
         "pyspark.sql.types",
         "pyspark.sql.context",
+        "pyspark.sql.session",
+        "pyspark.sql.conf",
+        "pyspark.sql.catalog",
         "pyspark.sql.column",
         "pyspark.sql.dataframe",
         "pyspark.sql.group",
diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
new file mode 100644
index 0000000000..4f9238374a
--- /dev/null
+++ b/python/pyspark/sql/catalog.py
@@ -0,0 +1,426 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from collections import namedtuple
+
+from pyspark import since
+from pyspark.rdd import ignore_unicode_prefix
+from pyspark.sql.dataframe import DataFrame
+from pyspark.sql.functions import UserDefinedFunction
+from pyspark.sql.types import IntegerType, StringType, StructType
+
+
+Database = namedtuple("Database", "name description locationUri")
+Table = namedtuple("Table", "name database description tableType isTemporary")
+Column = namedtuple("Column", "name description dataType nullable isPartition isBucket")
+Function = namedtuple("Function", "name description className isTemporary")
+
+
+class Catalog(object):
+    """User-facing catalog API, accessible through `SparkSession.catalog`.
+
+    This is a thin wrapper around its Scala implementation org.apache.spark.sql.catalog.Catalog.
+    """
+
+    def __init__(self, sparkSession):
+        """Create a new Catalog that wraps the underlying JVM object."""
+        self._sparkSession = sparkSession
+        self._jsparkSession = sparkSession._jsparkSession
+        self._jcatalog = sparkSession._jsparkSession.catalog()
+
+    @ignore_unicode_prefix
+    @since(2.0)
+    def currentDatabase(self):
+        """Returns the current default database in this session.
+
+        >>> spark.catalog._reset()
+        >>> spark.catalog.currentDatabase()
+        u'default'
+        """
+        return self._jcatalog.currentDatabase()
+
+    @ignore_unicode_prefix
+    @since(2.0)
+    def setCurrentDatabase(self, dbName):
+        """Sets the current default database in this session.
+
+        >>> spark.catalog._reset()
+        >>> spark.sql("CREATE DATABASE some_db")
+        DataFrame[]
+        >>> spark.catalog.setCurrentDatabase("some_db")
+        >>> spark.catalog.currentDatabase()
+        u'some_db'
+        >>> spark.catalog.setCurrentDatabase("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
+        Traceback (most recent call last):
+            ...
+        AnalysisException: ...
+        """
+        return self._jcatalog.setCurrentDatabase(dbName)
+
+    @ignore_unicode_prefix
+    @since(2.0)
+    def listDatabases(self):
+        """Returns a list of databases available across all sessions.
+
+        >>> spark.catalog._reset()
+        >>> [db.name for db in spark.catalog.listDatabases()]
+        [u'default']
+        >>> spark.sql("CREATE DATABASE some_db")
+        DataFrame[]
+        >>> [db.name for db in spark.catalog.listDatabases()]
+        [u'default', u'some_db']
+        """
+        iter = self._jcatalog.listDatabases().toLocalIterator()
+        databases = []
+        while iter.hasNext():
+            jdb = iter.next()
+            databases.append(Database(
+                name=jdb.name(),
+                description=jdb.description(),
+                locationUri=jdb.locationUri()))
+        return databases
+
+    @ignore_unicode_prefix
+    @since(2.0)
+    def listTables(self, dbName=None):
+        """Returns a list of tables in the specified database.
+
+        If no database is specified, the current database is used.
+        This includes all temporary tables.
+
+        >>> spark.catalog._reset()
+        >>> spark.sql("CREATE DATABASE some_db")
+        DataFrame[]
+        >>> spark.catalog.listTables()
+        []
+        >>> spark.catalog.listTables("some_db")
+        []
+        >>> spark.createDataFrame([(1, 1)]).registerTempTable("my_temp_tab")
+        >>> spark.sql("CREATE TABLE my_tab1 (name STRING, age INT)")
+        DataFrame[]
+        >>> spark.sql("CREATE TABLE some_db.my_tab2 (name STRING, age INT)")
+        DataFrame[]
+        >>> spark.catalog.listTables()
+        [Table(name=u'my_tab1', database=u'default', description=None, tableType=u'MANAGED',
+        isTemporary=False), Table(name=u'my_temp_tab', database=None, description=None,
+        tableType=u'TEMPORARY', isTemporary=True)]
+        >>> spark.catalog.listTables("some_db")
+        [Table(name=u'my_tab2', database=u'some_db', description=None, tableType=u'MANAGED',
+        isTemporary=False), Table(name=u'my_temp_tab', database=None, description=None,
+        tableType=u'TEMPORARY', isTemporary=True)]
+        >>> spark.catalog.listTables("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
+        Traceback (most recent call last):
+            ...
+        AnalysisException: ...
+        """
+        if dbName is None:
+            dbName = self.currentDatabase()
+        iter = self._jcatalog.listTables(dbName).toLocalIterator()
+        tables = []
+        while iter.hasNext():
+            jtable = iter.next()
+            tables.append(Table(
+                name=jtable.name(),
+                database=jtable.database(),
+                description=jtable.description(),
+                tableType=jtable.tableType(),
+                isTemporary=jtable.isTemporary()))
+        return tables
+
+    @ignore_unicode_prefix
+    @since(2.0)
+    def listFunctions(self, dbName=None):
+        """Returns a list of functions registered in the specified database.
+
+        If no database is specified, the current database is used.
+        This includes all temporary functions.
+
+        >>> spark.catalog._reset()
+        >>> spark.sql("CREATE DATABASE my_db")
+        DataFrame[]
+        >>> funcNames = set(f.name for f in spark.catalog.listFunctions())
+        >>> set(["+", "floor", "to_unix_timestamp", "current_database"]).issubset(funcNames)
+        True
+        >>> spark.sql("CREATE FUNCTION my_func1 AS 'org.apache.spark.whatever'")
+        DataFrame[]
+        >>> spark.sql("CREATE FUNCTION my_db.my_func2 AS 'org.apache.spark.whatever'")
+        DataFrame[]
+        >>> spark.catalog.registerFunction("temp_func", lambda x: str(x))
+        >>> newFuncNames = set(f.name for f in spark.catalog.listFunctions()) - funcNames
+        >>> newFuncNamesDb = set(f.name for f in spark.catalog.listFunctions("my_db")) - funcNames
+        >>> sorted(list(newFuncNames))
+        [u'my_func1', u'temp_func']
+        >>> sorted(list(newFuncNamesDb))
+        [u'my_func2', u'temp_func']
+        >>> spark.catalog.listFunctions("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
+        Traceback (most recent call last):
+            ...
+        AnalysisException: ...
+        """
+        if dbName is None:
+            dbName = self.currentDatabase()
+        iter = self._jcatalog.listFunctions(dbName).toLocalIterator()
+        functions = []
+        while iter.hasNext():
+            jfunction = iter.next()
+            functions.append(Function(
+                name=jfunction.name(),
+                description=jfunction.description(),
+                className=jfunction.className(),
+                isTemporary=jfunction.isTemporary()))
+        return functions
+
+    @ignore_unicode_prefix
+    @since(2.0)
+    def listColumns(self, tableName, dbName=None):
+        """Returns a list of columns for the given table in the specified database.
+
+        If no database is specified, the current database is used.
+
+        Note: the order of arguments here is different from that of its JVM counterpart
+        because Python does not support method overloading.
+
+        >>> spark.catalog._reset()
+        >>> spark.sql("CREATE DATABASE some_db")
+        DataFrame[]
+        >>> spark.sql("CREATE TABLE my_tab1 (name STRING, age INT)")
+        DataFrame[]
+        >>> spark.sql("CREATE TABLE some_db.my_tab2 (nickname STRING, tolerance FLOAT)")
+        DataFrame[]
+        >>> spark.catalog.listColumns("my_tab1")
+        [Column(name=u'name', description=None, dataType=u'string', nullable=True,
+        isPartition=False, isBucket=False), Column(name=u'age', description=None,
+        dataType=u'int', nullable=True, isPartition=False, isBucket=False)]
+        >>> spark.catalog.listColumns("my_tab2", "some_db")
+        [Column(name=u'nickname', description=None, dataType=u'string', nullable=True,
+        isPartition=False, isBucket=False), Column(name=u'tolerance', description=None,
+        dataType=u'float', nullable=True, isPartition=False, isBucket=False)]
+        >>> spark.catalog.listColumns("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
+        Traceback (most recent call last):
+            ...
+        AnalysisException: ...
+        """
+        if dbName is None:
+            dbName = self.currentDatabase()
+        iter = self._jcatalog.listColumns(dbName, tableName).toLocalIterator()
+        columns = []
+        while iter.hasNext():
+            jcolumn = iter.next()
+            columns.append(Column(
+                name=jcolumn.name(),
+                description=jcolumn.description(),
+                dataType=jcolumn.dataType(),
+                nullable=jcolumn.nullable(),
+                isPartition=jcolumn.isPartition(),
+                isBucket=jcolumn.isBucket()))
+        return columns
+
+    @since(2.0)
+    def createExternalTable(self, tableName, path=None, source=None, schema=None, **options):
+        """Creates an external table based on the dataset in a data source.
+
+        It returns the DataFrame associated with the external table.
+
+        The data source is specified by the ``source`` and a set of ``options``.
+        If ``source`` is not specified, the default data source configured by
+        ``spark.sql.sources.default`` will be used.
+
+        Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and
+        created external table.
+
+        :return: :class:`DataFrame`
+        """
+        if path is not None:
+            options["path"] = path
+        if source is None:
+            source = self._sparkSession.getConf(
+                "spark.sql.sources.default", "org.apache.spark.sql.parquet")
+        if schema is None:
+            df = self._jcatalog.createExternalTable(tableName, source, options)
+        else:
+            if not isinstance(schema, StructType):
+                raise TypeError("schema should be StructType")
+            scala_datatype = self._jsparkSession.parseDataType(schema.json())
+            df = self._jcatalog.createExternalTable(tableName, source, scala_datatype, options)
+        return DataFrame(df, self._sparkSession._wrapped)
+
+    @since(2.0)
+    def dropTempTable(self, tableName):
+        """Drops the temporary table with the given table name in the catalog.
+        If the table has been cached before, then it will also be uncached.
+
+        >>> spark.createDataFrame([(1, 1)]).registerTempTable("my_table")
+        >>> spark.table("my_table").collect()
+        [Row(_1=1, _2=1)]
+        >>> spark.catalog.dropTempTable("my_table")
+        >>> spark.table("my_table") # doctest: +IGNORE_EXCEPTION_DETAIL
+        Traceback (most recent call last):
+            ...
+        AnalysisException: ...
+        """
+        self._jcatalog.dropTempTable(tableName)
+
+    @since(2.0)
+    def registerDataFrameAsTable(self, df, tableName):
+        """Registers the given :class:`DataFrame` as a temporary table in the catalog.
+
+        >>> df = spark.createDataFrame([(2, 1), (3, 1)])
+        >>> spark.catalog.registerDataFrameAsTable(df, "my_cool_table")
+        >>> spark.table("my_cool_table").collect()
+        [Row(_1=2, _2=1), Row(_1=3, _2=1)]
+        """
+        if isinstance(df, DataFrame):
+            self._jsparkSession.registerDataFrameAsTable(df._jdf, tableName)
+        else:
+            raise ValueError("Can only register DataFrame as table")
+
+    @ignore_unicode_prefix
+    @since(2.0)
+    def registerFunction(self, name, f, returnType=StringType()):
+        """Registers a python function (including lambda function) as a UDF
+        so it can be used in SQL statements.
+
+        In addition to a name and the function itself, the return type can be optionally specified.
+        When the return type is not given it defaults to a string and conversion will automatically
+        be done. For any other return type, the produced object must match the specified type.
+
+        :param name: name of the UDF
+        :param f: python function
+        :param returnType: a :class:`DataType` object
+
+        >>> spark.catalog.registerFunction("stringLengthString", lambda x: len(x))
+        >>> spark.sql("SELECT stringLengthString('test')").collect()
+        [Row(stringLengthString(test)=u'4')]
+
+        >>> from pyspark.sql.types import IntegerType
+        >>> spark.catalog.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
+        >>> spark.sql("SELECT stringLengthInt('test')").collect()
+        [Row(stringLengthInt(test)=4)]
+
+        >>> from pyspark.sql.types import IntegerType
+        >>> spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
+        >>> spark.sql("SELECT stringLengthInt('test')").collect()
+        [Row(stringLengthInt(test)=4)]
+        """
+        udf = UserDefinedFunction(f, returnType, name)
+        self._jsparkSession.udf().registerPython(name, udf._judf)
+
+    @since(2.0)
+    def isCached(self, tableName):
+        """Returns true if the table is currently cached in-memory.
+
+        >>> spark.catalog._reset()
+        >>> spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("my_tab")
+        >>> spark.catalog.isCached("my_tab")
+        False
+        >>> spark.catalog.cacheTable("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
+        Traceback (most recent call last):
+            ...
+        AnalysisException: ...
+        """
+        return self._jcatalog.isCached(tableName)
+
+    @since(2.0)
+    def cacheTable(self, tableName):
+        """Caches the specified table in-memory.
+
+        >>> spark.catalog._reset()
+        >>> spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("my_tab")
+        >>> spark.catalog.isCached("my_tab")
+        False
+        >>> spark.catalog.cacheTable("my_tab")
+        >>> spark.catalog.isCached("my_tab")
+        True
+        >>> spark.catalog.cacheTable("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
+        Traceback (most recent call last):
+            ...
+        AnalysisException: ...
+        """
+        self._jcatalog.cacheTable(tableName)
+
+    @since(2.0)
+    def uncacheTable(self, tableName):
+        """Removes the specified table from the in-memory cache.
+
+        >>> spark.catalog._reset()
+        >>> spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("my_tab")
+        >>> spark.catalog.cacheTable("my_tab")
+        >>> spark.catalog.isCached("my_tab")
+        True
+        >>> spark.catalog.uncacheTable("my_tab")
+        >>> spark.catalog.isCached("my_tab")
+        False
+        >>> spark.catalog.uncacheTable("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
+        Traceback (most recent call last):
+            ...
+        AnalysisException: ...
+        """
+        self._jcatalog.uncacheTable(tableName)
+
+    @since(2.0)
+    def clearCache(self):
+        """Removes all cached tables from the in-memory cache.
+
+        >>> spark.catalog._reset()
+        >>> spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("my_tab1")
+        >>> spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("my_tab2")
+        >>> spark.catalog.cacheTable("my_tab1")
+        >>> spark.catalog.cacheTable("my_tab2")
+        >>> spark.catalog.isCached("my_tab1")
+        True
+        >>> spark.catalog.isCached("my_tab2")
+        True
+        >>> spark.catalog.clearCache()
+        >>> spark.catalog.isCached("my_tab1")
+        False
+        >>> spark.catalog.isCached("my_tab2")
+        False
+        """
+        self._jcatalog.clearCache()
+
+    def _reset(self):
+        """(Internal use only) Drop all existing databases (except "default"), tables,
+        partitions and functions, and set the current database to "default".
+
+        This is mainly used for tests.
+        """
+        self._jsparkSession.sessionState().catalog().reset()
+
+
+def _test():
+    import os
+    import doctest
+    from pyspark.context import SparkContext
+    from pyspark.sql.session import SparkSession
+    import pyspark.sql.catalog
+
+    os.chdir(os.environ["SPARK_HOME"])
+
+    globs = pyspark.sql.catalog.__dict__.copy()
+    sc = SparkContext('local[4]', 'PythonTest')
+    globs['sc'] = sc
+    globs['spark'] = SparkSession(sc)
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.sql.catalog,
+        globs=globs,
+        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
+    globs['sc'].stop()
+    if failure_count:
+        exit(-1)
+
+if __name__ == "__main__":
+    _test()
diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py
new file mode 100644
index 0000000000..1d9f052e25
--- /dev/null
+++ b/python/pyspark/sql/conf.py
@@ -0,0 +1,114 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark import since
+from pyspark.rdd import ignore_unicode_prefix
+
+
+class RuntimeConfig(object):
+    """User-facing configuration API, accessible through `SparkSession.conf`.
+
+    Options set here are automatically propagated to the Hadoop configuration during I/O.
+    This is a thin wrapper around its Scala implementation org.apache.spark.sql.RuntimeConfig.
+    """
+
+    def __init__(self, jconf):
+        """Create a new RuntimeConfig that wraps the underlying JVM object."""
+        self._jconf = jconf
+
+    @ignore_unicode_prefix
+    @since(2.0)
+    def set(self, key, value):
+        """Sets the given Spark runtime configuration property.
+
+        >>> spark.conf.set("garble", "marble")
+        >>> spark.getConf("garble")
+        u'marble'
+        """
+        self._jconf.set(key, value)
+
+    @ignore_unicode_prefix
+    @since(2.0)
+    def get(self, key):
+        """Returns the value of Spark runtime configuration property for the given key,
+        assuming it is set.
+
+        >>> spark.setConf("bogo", "sipeo")
+        >>> spark.conf.get("bogo")
+        u'sipeo'
+        >>> spark.conf.get("definitely.not.set") # doctest: +IGNORE_EXCEPTION_DETAIL
+        Traceback (most recent call last):
+            ...
+        Py4JJavaError: ...
+        """
+        return self._jconf.get(key)
+
+    @ignore_unicode_prefix
+    @since(2.0)
+    def getOption(self, key):
+        """Returns the value of Spark runtime configuration property for the given key,
+        or None if it is not set.
+
+        >>> spark.setConf("bogo", "sipeo")
+        >>> spark.conf.getOption("bogo")
+        u'sipeo'
+        >>> spark.conf.getOption("definitely.not.set") is None
+        True
+        """
+        iter = self._jconf.getOption(key).iterator()
+        if iter.hasNext():
+            return iter.next()
+        else:
+            return None
+
+    @ignore_unicode_prefix
+    @since(2.0)
+    def unset(self, key):
+        """Resets the configuration property for the given key.
+
+        >>> spark.setConf("armado", "larmado")
+        >>> spark.getConf("armado")
+        u'larmado'
+        >>> spark.conf.unset("armado")
+        >>> spark.getConf("armado") # doctest: +IGNORE_EXCEPTION_DETAIL
+        Traceback (most recent call last):
+            ...
+        Py4JJavaError: ...
+        """
+        self._jconf.unset(key)
+
+
+def _test():
+    import os
+    import doctest
+    from pyspark.context import SparkContext
+    from pyspark.sql.session import SparkSession
+    import pyspark.sql.conf
+
+    os.chdir(os.environ["SPARK_HOME"])
+
+    globs = pyspark.sql.conf.__dict__.copy()
+    sc = SparkContext('local[4]', 'PythonTest')
+    globs['sc'] = sc
+    globs['spark'] = SparkSession(sc)
+    (failure_count, test_count) = doctest.testmod(pyspark.sql.conf, globs=globs)
+    globs['sc'].stop()
+    if failure_count:
+        exit(-1)
+
+if __name__ == "__main__":
+    _test()
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index a3ea192b28..94856c245b 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -142,7 +142,7 @@ class SQLContext(object):
 
         :return: :class:`UDFRegistration`
         """
-        return UDFRegistration(self.sparkSession)
+        return UDFRegistration(self)
 
     @since(1.4)
     def range(self, start, end=None, step=1, numPartitions=None):
@@ -195,7 +195,7 @@ class SQLContext(object):
         >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]
         """
-        self.sparkSession.registerFunction(name, f, returnType)
+        self.sparkSession.catalog.registerFunction(name, f, returnType)
 
     # TODO(andrew): delete this once we refactor things to take in SparkSession
     def _inferSchema(self, rdd, samplingRatio=None):
@@ -301,7 +301,7 @@ class SQLContext(object):
 
         >>> sqlContext.registerDataFrameAsTable(df, "table1")
         """
-        self.sparkSession.registerDataFrameAsTable(df, tableName)
+        self.sparkSession.catalog.registerDataFrameAsTable(df, tableName)
 
     @since(1.6)
     def dropTempTable(self, tableName):
@@ -310,7 +310,7 @@ class SQLContext(object):
         >>> sqlContext.registerDataFrameAsTable(df, "table1")
         >>> sqlContext.dropTempTable("table1")
         """
-        self._ssql_ctx.dropTempTable(tableName)
+        self.sparkSession.catalog.dropTempTable(tableName)
 
     @since(1.3)
     def createExternalTable(self, tableName, path=None, source=None, schema=None, **options):
@@ -327,7 +327,8 @@ class SQLContext(object):
 
         :return: :class:`DataFrame`
         """
-        return self.sparkSession.createExternalTable(tableName, path, source, schema, **options)
+        return self.sparkSession.catalog.createExternalTable(
+            tableName, path, source, schema, **options)
 
     @ignore_unicode_prefix
     @since(1.0)
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index d3355f9da7..b3bc8961b8 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -27,8 +27,9 @@ else:
 
 from pyspark import since
 from pyspark.rdd import RDD, ignore_unicode_prefix
+from pyspark.sql.catalog import Catalog
+from pyspark.sql.conf import RuntimeConfig
 from pyspark.sql.dataframe import DataFrame
-from pyspark.sql.functions import UserDefinedFunction
 from pyspark.sql.readwriter import DataFrameReader
 from pyspark.sql.types import Row, DataType, StringType, StructType, _verify_type, \
     _infer_schema, _has_nulltype, _merge_type, _create_converter, _parse_datatype_string
@@ -56,7 +57,6 @@ def _monkey_patch_RDD(sparkSession):
     RDD.toDF = toDF
 
 
-# TODO(andrew): implement conf and catalog namespaces
 class SparkSession(object):
     """Main entry point for Spark SQL functionality.
 
@@ -105,7 +105,7 @@ class SparkSession(object):
     @classmethod
     @since(2.0)
     def withHiveSupport(cls, sparkContext):
-        """Returns a new SparkSession with a catalog backed by Hive
+        """Returns a new SparkSession with a catalog backed by Hive.
 
         :param sparkContext: The underlying :class:`SparkContext`.
""" @@ -121,6 +121,19 @@ class SparkSession(object): """ return self.__class__(self._sc, self._jsparkSession.newSession()) + @property + @since(2.0) + def conf(self): + """Runtime configuration interface for Spark. + + This is the interface through which the user can get and set all Spark and Hadoop + configurations that are relevant to Spark SQL. When getting the value of a config, + this defaults to the value set in the underlying :class:`SparkContext`, if any. + """ + if not hasattr(self, "_conf"): + self._conf = RuntimeConfig(self._jsparkSession.conf()) + return self._conf + @since(2.0) def setConf(self, key, value): """ @@ -150,6 +163,16 @@ class SparkSession(object): else: return self._jsparkSession.getConf(key) + @property + @since(2.0) + def catalog(self): + """Interface through which the user may create, drop, alter or query underlying + databases, tables, functions etc. + """ + if not hasattr(self, "_catalog"): + self._catalog = Catalog(self) + return self._catalog + @property @since(2.0) def udf(self): @@ -157,7 +180,8 @@ class SparkSession(object): :return: :class:`UDFRegistration` """ - return UDFRegistration(self) + from pyspark.sql.context import UDFRegistration + return UDFRegistration(self._wrapped) @since(2.0) def range(self, start, end=None, step=1, numPartitions=None): @@ -190,37 +214,6 @@ class SparkSession(object): return DataFrame(jdf, self._wrapped) - @ignore_unicode_prefix - @since(2.0) - def registerFunction(self, name, f, returnType=StringType()): - """Registers a python function (including lambda function) as a UDF - so it can be used in SQL statements. - - In addition to a name and the function itself, the return type can be optionally specified. - When the return type is not given it default to a string and conversion will automatically - be done. For any other return type, the produced object must match the specified type. - - :param name: name of the UDF - :param f: python function - :param returnType: a :class:`DataType` object - - >>> spark.registerFunction("stringLengthString", lambda x: len(x)) - >>> spark.sql("SELECT stringLengthString('test')").collect() - [Row(stringLengthString(test)=u'4')] - - >>> from pyspark.sql.types import IntegerType - >>> spark.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) - >>> spark.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] - - >>> from pyspark.sql.types import IntegerType - >>> spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) - >>> spark.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] - """ - udf = UserDefinedFunction(f, returnType, name) - self._jsparkSession.udf().registerPython(name, udf._judf) - def _inferSchemaFromList(self, data): """ Infer schema from list of Row or tuple. @@ -443,49 +436,6 @@ class SparkSession(object): df._schema = schema return df - @since(2.0) - def registerDataFrameAsTable(self, df, tableName): - """Registers the given :class:`DataFrame` as a temporary table in the catalog. - - Temporary tables exist only during the lifetime of this instance of :class:`SparkSession`. - - >>> spark.registerDataFrameAsTable(df, "table1") - """ - if (df.__class__ is DataFrame): - self._jsparkSession.registerDataFrameAsTable(df._jdf, tableName) - else: - raise ValueError("Can only register DataFrame as table") - - @since(2.0) - def createExternalTable(self, tableName, path=None, source=None, schema=None, **options): - """Creates an external table based on the dataset in a data source. 
-
-        It returns the DataFrame associated with the external table.
-
-        The data source is specified by the ``source`` and a set of ``options``.
-        If ``source`` is not specified, the default data source configured by
-        ``spark.sql.sources.default`` will be used.
-
-        Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and
-        created external table.
-
-        :return: :class:`DataFrame`
-        """
-        if path is not None:
-            options["path"] = path
-        if source is None:
-            source = self.getConf("spark.sql.sources.default",
-                                  "org.apache.spark.sql.parquet")
-        if schema is None:
-            df = self._jsparkSession.catalog().createExternalTable(tableName, source, options)
-        else:
-            if not isinstance(schema, StructType):
-                raise TypeError("schema should be StructType")
-            scala_datatype = self._jsparkSession.parseDataType(schema.json())
-            df = self._jsparkSession.catalog().createExternalTable(
-                tableName, source, scala_datatype, options)
-        return DataFrame(df, self._wrapped)
-
     @ignore_unicode_prefix
     @since(2.0)
     def sql(self, sqlQuery):
@@ -493,7 +443,7 @@ class SparkSession(object):
 
         :return: :class:`DataFrame`
 
-        >>> spark.registerDataFrameAsTable(df, "table1")
+        >>> spark.catalog.registerDataFrameAsTable(df, "table1")
         >>> df2 = spark.sql("SELECT field1 AS f1, field2 as f2 from table1")
         >>> df2.collect()
         [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
@@ -506,7 +456,7 @@ class SparkSession(object):
 
         :return: :class:`DataFrame`
 
-        >>> spark.registerDataFrameAsTable(df, "table1")
+        >>> spark.catalog.registerDataFrameAsTable(df, "table1")
         >>> df2 = spark.table("table1")
         >>> sorted(df.collect()) == sorted(df2.collect())
         True
@@ -523,3 +473,32 @@ class SparkSession(object):
         :return: :class:`DataFrameReader`
         """
         return DataFrameReader(self._wrapped)
+
+
+def _test():
+    import os
+    import doctest
+    from pyspark.context import SparkContext
+    from pyspark.sql import Row
+    import pyspark.sql.session
+
+    os.chdir(os.environ["SPARK_HOME"])
+
+    globs = pyspark.sql.session.__dict__.copy()
+    sc = SparkContext('local[4]', 'PythonTest')
+    globs['sc'] = sc
+    globs['spark'] = SparkSession(sc)
+    globs['rdd'] = rdd = sc.parallelize(
+        [Row(field1=1, field2="row1"),
+         Row(field1=2, field2="row2"),
+         Row(field1=3, field2="row3")])
+    globs['df'] = rdd.toDF()
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.sql.session, globs=globs,
+        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
+    globs['sc'].stop()
+    if failure_count:
+        exit(-1)
+
+if __name__ == "__main__":
+    _test()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
index 868cc3a726..7a815c1f99 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
@@ -176,9 +176,9 @@ abstract class Catalog {
 
   /**
    * Drops the temporary table with the given table name in the catalog.
-   * If the table has been cached/persisted before, it's also unpersisted.
+   * If the table has been cached before, then it will also be uncached.
    *
-   * @param tableName the name of the table to be unregistered.
+   * @param tableName the name of the table to be dropped.
   * @since 2.0.0
   */
  def dropTempTable(tableName: String): Unit
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala
index d5de6cd484..0f7feb8eee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala
@@ -83,6 +83,7 @@ class Column(
 }
 
 
+// TODO(andrew): should we include the database here?
 class Function(
     val name: String,
     @Nullable val description: String,
-- cgit v1.2.3