path: root/python/pyspark/sql/session.py
author    Andrew Or <andrew@databricks.com>    2016-04-29 09:34:10 -0700
committer Andrew Or <andrew@databricks.com>    2016-04-29 09:34:10 -0700
commit    a7d0fedc940721d09350f2e57ae85591e0a3d90e (patch)
tree      5f96e980e810cd13f36658ed052a1e987c5d261c /python/pyspark/sql/session.py
parent    7feeb82cb7f462e44f7e698c7c3b6ac3a77aade4 (diff)
[SPARK-14988][PYTHON] SparkSession catalog and conf API
## What changes were proposed in this pull request?

The `catalog` and `conf` APIs were exposed in `SparkSession` in #12713 and #12669. This patch adds those to the Python API.

## How was this patch tested?

Python tests.

Author: Andrew Or <andrew@databricks.com>

Closes #12765 from andrewor14/python-spark-session-more.
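For orientation, here is a minimal usage sketch of the entry point these additions hang off of. This is not part of the patch itself: at this commit a SparkSession is still constructed directly from a SparkContext (as in the doctest harness added at the bottom of the diff), and only the two property names below are taken from the patch.

    # Minimal sketch, mirroring the _test() harness added by this patch.
    from pyspark import SparkContext
    from pyspark.sql import Row
    from pyspark.sql.session import SparkSession

    sc = SparkContext("local[4]", "session-demo")
    spark = SparkSession(sc)

    df = sc.parallelize([Row(field1=1, field2="row1")]).toDF()

    # The two namespaces this patch exposes on SparkSession:
    spark.conf      # runtime configuration (RuntimeConfig)
    spark.catalog   # database/table/function operations (Catalog)

    sc.stop()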
Diffstat (limited to 'python/pyspark/sql/session.py')
-rw-r--r--    python/pyspark/sql/session.py    139
1 file changed, 59 insertions(+), 80 deletions(-)
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index d3355f9da7..b3bc8961b8 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -27,8 +27,9 @@ else:
from pyspark import since
from pyspark.rdd import RDD, ignore_unicode_prefix
+from pyspark.sql.catalog import Catalog
+from pyspark.sql.conf import RuntimeConfig
from pyspark.sql.dataframe import DataFrame
-from pyspark.sql.functions import UserDefinedFunction
from pyspark.sql.readwriter import DataFrameReader
from pyspark.sql.types import Row, DataType, StringType, StructType, _verify_type, \
_infer_schema, _has_nulltype, _merge_type, _create_converter, _parse_datatype_string
@@ -56,7 +57,6 @@ def _monkey_patch_RDD(sparkSession):
RDD.toDF = toDF
-# TODO(andrew): implement conf and catalog namespaces
class SparkSession(object):
"""Main entry point for Spark SQL functionality.
@@ -105,7 +105,7 @@ class SparkSession(object):
@classmethod
@since(2.0)
def withHiveSupport(cls, sparkContext):
- """Returns a new SparkSession with a catalog backed by Hive
+ """Returns a new SparkSession with a catalog backed by Hive.
:param sparkContext: The underlying :class:`SparkContext`.
"""
@@ -121,6 +121,19 @@ class SparkSession(object):
"""
return self.__class__(self._sc, self._jsparkSession.newSession())
+ @property
+ @since(2.0)
+ def conf(self):
+ """Runtime configuration interface for Spark.
+
+ This is the interface through which the user can get and set all Spark and Hadoop
+ configurations that are relevant to Spark SQL. When getting the value of a config,
+ this defaults to the value set in the underlying :class:`SparkContext`, if any.
+ """
+ if not hasattr(self, "_conf"):
+ self._conf = RuntimeConfig(self._jsparkSession.conf())
+ return self._conf
+
@since(2.0)
def setConf(self, key, value):
"""
@@ -152,12 +165,23 @@ class SparkSession(object):
@property
@since(2.0)
+ def catalog(self):
+ """Interface through which the user may create, drop, alter or query underlying
+ databases, tables, functions etc.
+ """
+ if not hasattr(self, "_catalog"):
+ self._catalog = Catalog(self)
+ return self._catalog
+
+ @property
+ @since(2.0)
def udf(self):
"""Returns a :class:`UDFRegistration` for UDF registration.
:return: :class:`UDFRegistration`
"""
- return UDFRegistration(self)
+ from pyspark.sql.context import UDFRegistration
+ return UDFRegistration(self._wrapped)
@since(2.0)
def range(self, start, end=None, step=1, numPartitions=None):
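A sketch of the new `catalog` property. The only Catalog method used here is `registerDataFrameAsTable`, which the updated doctests later in this diff call, so no other method names are assumed.

    # Assumes an existing SparkSession `spark` and DataFrame `df`.
    spark.catalog.registerDataFrameAsTable(df, "table1")
    spark.sql("SELECT * FROM table1").collect()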
@@ -190,37 +214,6 @@ class SparkSession(object):
return DataFrame(jdf, self._wrapped)
- @ignore_unicode_prefix
- @since(2.0)
- def registerFunction(self, name, f, returnType=StringType()):
- """Registers a python function (including lambda function) as a UDF
- so it can be used in SQL statements.
-
- In addition to a name and the function itself, the return type can be optionally specified.
- When the return type is not given it default to a string and conversion will automatically
- be done. For any other return type, the produced object must match the specified type.
-
- :param name: name of the UDF
- :param f: python function
- :param returnType: a :class:`DataType` object
-
- >>> spark.registerFunction("stringLengthString", lambda x: len(x))
- >>> spark.sql("SELECT stringLengthString('test')").collect()
- [Row(stringLengthString(test)=u'4')]
-
- >>> from pyspark.sql.types import IntegerType
- >>> spark.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
- >>> spark.sql("SELECT stringLengthInt('test')").collect()
- [Row(stringLengthInt(test)=4)]
-
- >>> from pyspark.sql.types import IntegerType
- >>> spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
- >>> spark.sql("SELECT stringLengthInt('test')").collect()
- [Row(stringLengthInt(test)=4)]
- """
- udf = UserDefinedFunction(f, returnType, name)
- self._jsparkSession.udf().registerPython(name, udf._judf)
-
def _inferSchemaFromList(self, data):
"""
Infer schema from list of Row or tuple.
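SparkSession.registerFunction is removed here; its replacement is the `udf` property added above. The example below is lifted from the removed docstring's own `spark.udf.register` doctest rather than invented.

    from pyspark.sql.types import IntegerType

    # Equivalent of the removed spark.registerFunction(...) call.
    spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
    spark.sql("SELECT stringLengthInt('test')").collect()
    # [Row(stringLengthInt(test)=4)]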
@@ -443,49 +436,6 @@ class SparkSession(object):
df._schema = schema
return df
- @since(2.0)
- def registerDataFrameAsTable(self, df, tableName):
- """Registers the given :class:`DataFrame` as a temporary table in the catalog.
-
- Temporary tables exist only during the lifetime of this instance of :class:`SparkSession`.
-
- >>> spark.registerDataFrameAsTable(df, "table1")
- """
- if (df.__class__ is DataFrame):
- self._jsparkSession.registerDataFrameAsTable(df._jdf, tableName)
- else:
- raise ValueError("Can only register DataFrame as table")
-
- @since(2.0)
- def createExternalTable(self, tableName, path=None, source=None, schema=None, **options):
- """Creates an external table based on the dataset in a data source.
-
- It returns the DataFrame associated with the external table.
-
- The data source is specified by the ``source`` and a set of ``options``.
- If ``source`` is not specified, the default data source configured by
- ``spark.sql.sources.default`` will be used.
-
- Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and
- created external table.
-
- :return: :class:`DataFrame`
- """
- if path is not None:
- options["path"] = path
- if source is None:
- source = self.getConf("spark.sql.sources.default",
- "org.apache.spark.sql.parquet")
- if schema is None:
- df = self._jsparkSession.catalog().createExternalTable(tableName, source, options)
- else:
- if not isinstance(schema, StructType):
- raise TypeError("schema should be StructType")
- scala_datatype = self._jsparkSession.parseDataType(schema.json())
- df = self._jsparkSession.catalog().createExternalTable(
- tableName, source, scala_datatype, options)
- return DataFrame(df, self._wrapped)
-
@ignore_unicode_prefix
@since(2.0)
def sql(self, sqlQuery):
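Both removed methods are expected to move under the catalog namespace. The registerDataFrameAsTable call is confirmed by the doctest updates just below; createExternalTable on the Python Catalog is an assumption here, inferred from the removed code's delegation to the JVM catalog, so treat this sketch as illustrative only.

    # Assumed migration path; spark.catalog.createExternalTable is presumed to
    # take the same arguments as the removed SparkSession method above.
    spark.catalog.registerDataFrameAsTable(df, "table1")
    people = spark.catalog.createExternalTable(
        "people_ext", path="/tmp/people.parquet", source="parquet")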
@@ -493,7 +443,7 @@ class SparkSession(object):
:return: :class:`DataFrame`
- >>> spark.registerDataFrameAsTable(df, "table1")
+ >>> spark.catalog.registerDataFrameAsTable(df, "table1")
>>> df2 = spark.sql("SELECT field1 AS f1, field2 as f2 from table1")
>>> df2.collect()
[Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
@@ -506,7 +456,7 @@ class SparkSession(object):
:return: :class:`DataFrame`
- >>> spark.registerDataFrameAsTable(df, "table1")
+ >>> spark.catalog.registerDataFrameAsTable(df, "table1")
>>> df2 = spark.table("table1")
>>> sorted(df.collect()) == sorted(df2.collect())
True
@@ -523,3 +473,32 @@ class SparkSession(object):
:return: :class:`DataFrameReader`
"""
return DataFrameReader(self._wrapped)
+
+
+def _test():
+ import os
+ import doctest
+ from pyspark.context import SparkContext
+ from pyspark.sql import Row
+ import pyspark.sql.session
+
+ os.chdir(os.environ["SPARK_HOME"])
+
+ globs = pyspark.sql.session.__dict__.copy()
+ sc = SparkContext('local[4]', 'PythonTest')
+ globs['sc'] = sc
+ globs['spark'] = SparkSession(sc)
+ globs['rdd'] = rdd = sc.parallelize(
+ [Row(field1=1, field2="row1"),
+ Row(field1=2, field2="row2"),
+ Row(field1=3, field2="row3")])
+ globs['df'] = rdd.toDF()
+ (failure_count, test_count) = doctest.testmod(
+ pyspark.sql.session, globs=globs,
+ optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+if __name__ == "__main__":
+ _test()