author     Andrew Or <andrew@databricks.com>    2016-04-29 09:34:10 -0700
committer  Andrew Or <andrew@databricks.com>    2016-04-29 09:34:10 -0700
commit     a7d0fedc940721d09350f2e57ae85591e0a3d90e (patch)
tree       5f96e980e810cd13f36658ed052a1e987c5d261c /python/pyspark/sql/session.py
parent     7feeb82cb7f462e44f7e698c7c3b6ac3a77aade4 (diff)
[SPARK-14988][PYTHON] SparkSession catalog and conf API
## What changes were proposed in this pull request?
The `catalog` and `conf` APIs were exposed in `SparkSession` in #12713 and #12669. This patch adds them to the Python API.
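As a rough illustration (a hedged sketch in this file's doctest style, not output copied from this patch's tests), the new properties are used like this, where `spark` and `df` are the doctest globals set up in `_test()` below:

    >>> spark.conf.set("spark.sql.shuffle.partitions", "10")
    >>> spark.conf.get("spark.sql.shuffle.partitions")
    u'10'
    >>> spark.catalog.registerDataFrameAsTable(df, "table1")

The `u'10'` repr assumes Python 2, matching the unicode reprs in the existing doctests; `conf.set`/`conf.get` are assumed to be the accessors of the `RuntimeConfig` wrapper imported in this patch.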
## How was this patch tested?
Python tests.
Author: Andrew Or <andrew@databricks.com>
Closes #12765 from andrewor14/python-spark-session-more.
Diffstat (limited to 'python/pyspark/sql/session.py')
-rw-r--r--  python/pyspark/sql/session.py  139
1 file changed, 59 insertions(+), 80 deletions(-)
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index d3355f9da7..b3bc8961b8 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -27,8 +27,9 @@ else:
 
 from pyspark import since
 from pyspark.rdd import RDD, ignore_unicode_prefix
+from pyspark.sql.catalog import Catalog
+from pyspark.sql.conf import RuntimeConfig
 from pyspark.sql.dataframe import DataFrame
-from pyspark.sql.functions import UserDefinedFunction
 from pyspark.sql.readwriter import DataFrameReader
 from pyspark.sql.types import Row, DataType, StringType, StructType, _verify_type, \
     _infer_schema, _has_nulltype, _merge_type, _create_converter, _parse_datatype_string
@@ -56,7 +57,6 @@ def _monkey_patch_RDD(sparkSession):
     RDD.toDF = toDF
 
 
-# TODO(andrew): implement conf and catalog namespaces
 class SparkSession(object):
     """Main entry point for Spark SQL functionality.
 
@@ -105,7 +105,7 @@ class SparkSession(object):
     @classmethod
     @since(2.0)
     def withHiveSupport(cls, sparkContext):
-        """Returns a new SparkSession with a catalog backed by Hive
+        """Returns a new SparkSession with a catalog backed by Hive.
 
         :param sparkContext: The underlying :class:`SparkContext`.
         """
@@ -121,6 +121,19 @@ class SparkSession(object):
         """
        return self.__class__(self._sc, self._jsparkSession.newSession())
 
+    @property
+    @since(2.0)
+    def conf(self):
+        """Runtime configuration interface for Spark.
+
+        This is the interface through which the user can get and set all Spark and Hadoop
+        configurations that are relevant to Spark SQL. When getting the value of a config,
+        this defaults to the value set in the underlying :class:`SparkContext`, if any.
+        """
+        if not hasattr(self, "_conf"):
+            self._conf = RuntimeConfig(self._jsparkSession.conf())
+        return self._conf
+
     @since(2.0)
     def setConf(self, key, value):
         """
@@ -152,12 +165,23 @@ class SparkSession(object):
 
     @property
     @since(2.0)
+    def catalog(self):
+        """Interface through which the user may create, drop, alter or query underlying
+        databases, tables, functions etc.
+        """
+        if not hasattr(self, "_catalog"):
+            self._catalog = Catalog(self)
+        return self._catalog
+
+    @property
+    @since(2.0)
     def udf(self):
         """Returns a :class:`UDFRegistration` for UDF registration.
 
         :return: :class:`UDFRegistration`
         """
-        return UDFRegistration(self)
+        from pyspark.sql.context import UDFRegistration
+        return UDFRegistration(self._wrapped)
 
     @since(2.0)
     def range(self, start, end=None, step=1, numPartitions=None):
@@ -190,37 +214,6 @@ class SparkSession(object):
 
         return DataFrame(jdf, self._wrapped)
 
-    @ignore_unicode_prefix
-    @since(2.0)
-    def registerFunction(self, name, f, returnType=StringType()):
-        """Registers a python function (including lambda function) as a UDF
-        so it can be used in SQL statements.
-
-        In addition to a name and the function itself, the return type can be optionally specified.
-        When the return type is not given it default to a string and conversion will automatically
-        be done. For any other return type, the produced object must match the specified type.
-
-        :param name: name of the UDF
-        :param f: python function
-        :param returnType: a :class:`DataType` object
-
-        >>> spark.registerFunction("stringLengthString", lambda x: len(x))
-        >>> spark.sql("SELECT stringLengthString('test')").collect()
-        [Row(stringLengthString(test)=u'4')]
-
-        >>> from pyspark.sql.types import IntegerType
-        >>> spark.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
-        >>> spark.sql("SELECT stringLengthInt('test')").collect()
-        [Row(stringLengthInt(test)=4)]
-
-        >>> from pyspark.sql.types import IntegerType
-        >>> spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
-        >>> spark.sql("SELECT stringLengthInt('test')").collect()
-        [Row(stringLengthInt(test)=4)]
-        """
-        udf = UserDefinedFunction(f, returnType, name)
-        self._jsparkSession.udf().registerPython(name, udf._judf)
-
     def _inferSchemaFromList(self, data):
         """
         Infer schema from list of Row or tuple.
@@ -443,49 +436,6 @@ class SparkSession(object):
         df._schema = schema
         return df
 
-    @since(2.0)
-    def registerDataFrameAsTable(self, df, tableName):
-        """Registers the given :class:`DataFrame` as a temporary table in the catalog.
-
-        Temporary tables exist only during the lifetime of this instance of :class:`SparkSession`.
-
-        >>> spark.registerDataFrameAsTable(df, "table1")
-        """
-        if (df.__class__ is DataFrame):
-            self._jsparkSession.registerDataFrameAsTable(df._jdf, tableName)
-        else:
-            raise ValueError("Can only register DataFrame as table")
-
-    @since(2.0)
-    def createExternalTable(self, tableName, path=None, source=None, schema=None, **options):
-        """Creates an external table based on the dataset in a data source.
-
-        It returns the DataFrame associated with the external table.
-
-        The data source is specified by the ``source`` and a set of ``options``.
-        If ``source`` is not specified, the default data source configured by
-        ``spark.sql.sources.default`` will be used.
-
-        Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and
-        created external table.
-
-        :return: :class:`DataFrame`
-        """
-        if path is not None:
-            options["path"] = path
-        if source is None:
-            source = self.getConf("spark.sql.sources.default",
-                                  "org.apache.spark.sql.parquet")
-        if schema is None:
-            df = self._jsparkSession.catalog().createExternalTable(tableName, source, options)
-        else:
-            if not isinstance(schema, StructType):
-                raise TypeError("schema should be StructType")
-            scala_datatype = self._jsparkSession.parseDataType(schema.json())
-            df = self._jsparkSession.catalog().createExternalTable(
-                tableName, source, scala_datatype, options)
-        return DataFrame(df, self._wrapped)
-
     @ignore_unicode_prefix
     @since(2.0)
     def sql(self, sqlQuery):
@@ -493,7 +443,7 @@ class SparkSession(object):
 
         :return: :class:`DataFrame`
 
-        >>> spark.registerDataFrameAsTable(df, "table1")
+        >>> spark.catalog.registerDataFrameAsTable(df, "table1")
         >>> df2 = spark.sql("SELECT field1 AS f1, field2 as f2 from table1")
         >>> df2.collect()
         [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
@@ -506,7 +456,7 @@ class SparkSession(object):
 
         :return: :class:`DataFrame`
 
-        >>> spark.registerDataFrameAsTable(df, "table1")
+        >>> spark.catalog.registerDataFrameAsTable(df, "table1")
        >>> df2 = spark.table("table1")
         >>> sorted(df.collect()) == sorted(df2.collect())
         True
@@ -523,3 +473,32 @@ class SparkSession(object):
         :return: :class:`DataFrameReader`
         """
         return DataFrameReader(self._wrapped)
+
+
+def _test():
+    import os
+    import doctest
+    from pyspark.context import SparkContext
+    from pyspark.sql import Row
+    import pyspark.sql.session
+
+    os.chdir(os.environ["SPARK_HOME"])
+
+    globs = pyspark.sql.session.__dict__.copy()
+    sc = SparkContext('local[4]', 'PythonTest')
+    globs['sc'] = sc
+    globs['spark'] = SparkSession(sc)
+    globs['rdd'] = rdd = sc.parallelize(
+        [Row(field1=1, field2="row1"),
+         Row(field1=2, field2="row2"),
+         Row(field1=3, field2="row3")])
+    globs['df'] = rdd.toDF()
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.sql.session, globs=globs,
+        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
+    globs['sc'].stop()
+    if failure_count:
+        exit(-1)
+
+if __name__ == "__main__":
+    _test()
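For callers migrating off the removed session-level methods, the replacements follow the updated doctests above. A hedged sketch: the `spark.udf.register` and `spark.catalog.registerDataFrameAsTable` calls appear verbatim in this diff, while the `spark.catalog.createExternalTable` call (and its table name and path) is an assumption mirroring the Scala `catalog().createExternalTable` used by the deleted Python code:

    >>> from pyspark.sql.types import IntegerType
    >>> spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
    >>> spark.sql("SELECT stringLengthInt('test')").collect()
    [Row(stringLengthInt(test)=4)]
    >>> spark.catalog.registerDataFrameAsTable(df, "table1")
    >>> df2 = spark.table("table1")
    >>> sorted(df.collect()) == sorted(df2.collect())
    True
    >>> # assumption: external tables move to the catalog namespace as well
    >>> df3 = spark.catalog.createExternalTable("ext_table", path="/tmp/ext", source="parquet")  # doctest: +SKIP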