From 89addd40abdacd65cc03ac8aa5f9cf3dd4a4c19b Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 28 Apr 2016 10:55:48 -0700 Subject: [SPARK-14945][PYTHON] SparkSession Python API ## What changes were proposed in this pull request? ``` Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /__ / .__/\_,_/_/ /_/\_\ version 2.0.0-SNAPSHOT /_/ Using Python version 2.7.5 (default, Mar 9 2014 22:15:05) SparkSession available as 'spark'. >>> spark >>> spark.sql("SHOW TABLES").show() ... +---------+-----------+ |tableName|isTemporary| +---------+-----------+ | src| false| +---------+-----------+ >>> spark.range(1, 10, 2).show() +---+ | id| +---+ | 1| | 3| | 5| | 7| | 9| +---+ ``` **Note**: This API is NOT complete in its current state. In particular, for now I left out the `conf` and `catalog` APIs, which were added later in Scala. These will be added later before 2.0. ## How was this patch tested? Python tests. Author: Andrew Or Closes #12746 from andrewor14/python-spark-session. --- python/pyspark/shell.py | 11 +- python/pyspark/sql/__init__.py | 5 +- python/pyspark/sql/context.py | 278 ++++----------------- python/pyspark/sql/readwriter.py | 2 +- python/pyspark/sql/session.py | 525 +++++++++++++++++++++++++++++++++++++++ python/pyspark/sql/tests.py | 4 +- 6 files changed, 585 insertions(+), 240 deletions(-) create mode 100644 python/pyspark/sql/session.py diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 7c37f75193..c6b0eda996 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -29,7 +29,7 @@ import py4j import pyspark from pyspark.context import SparkContext -from pyspark.sql import SQLContext, HiveContext +from pyspark.sql import SparkSession, SQLContext from pyspark.storagelevel import StorageLevel if os.environ.get("SPARK_EXECUTOR_URI"): @@ -41,13 +41,14 @@ atexit.register(lambda: sc.stop()) try: # Try to access HiveConf, it will raise exception if Hive is not added sc._jvm.org.apache.hadoop.hive.conf.HiveConf() - sqlContext = HiveContext(sc) + spark = SparkSession.withHiveSupport(sc) except py4j.protocol.Py4JError: - sqlContext = SQLContext(sc) + spark = SparkSession(sc) except TypeError: - sqlContext = SQLContext(sc) + spark = SparkSession(sc) # for compatibility +sqlContext = spark._wrapped sqlCtx = sqlContext print("""Welcome to @@ -61,7 +62,7 @@ print("Using Python version %s (%s, %s)" % ( platform.python_version(), platform.python_build()[0], platform.python_build()[1])) -print("SparkContext available as sc, %s available as sqlContext." 
% sqlContext.__class__.__name__) +print("SparkSession available as 'spark'.") # The ./bin/pyspark script stores the old PYTHONSTARTUP value in OLD_PYTHONSTARTUP, # which allows us to execute the user's PYTHONSTARTUP file: diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index 0b06c8339f..cff73ff192 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -46,6 +46,7 @@ from __future__ import absolute_import from pyspark.sql.types import Row from pyspark.sql.context import SQLContext, HiveContext +from pyspark.sql.session import SparkSession from pyspark.sql.column import Column from pyspark.sql.dataframe import DataFrame, DataFrameNaFunctions, DataFrameStatFunctions from pyspark.sql.group import GroupedData @@ -54,7 +55,7 @@ from pyspark.sql.window import Window, WindowSpec __all__ = [ - 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', - 'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec', + 'SparkSession', 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', + 'Row', 'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec', 'DataFrameReader', 'DataFrameWriter' ] diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 600a6e0bc2..48ffb59668 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -17,70 +17,37 @@ from __future__ import print_function import sys -import warnings -import json -from functools import reduce if sys.version >= '3': basestring = unicode = str -else: - from itertools import imap as map - -from py4j.protocol import Py4JError from pyspark import since -from pyspark.rdd import RDD, ignore_unicode_prefix -from pyspark.serializers import AutoBatchedSerializer, PickleSerializer -from pyspark.sql.types import Row, DataType, StringType, StructType, _verify_type, \ - _infer_schema, _has_nulltype, _merge_type, _create_converter, _parse_datatype_string +from pyspark.rdd import ignore_unicode_prefix +from pyspark.sql.session import _monkey_patch_RDD, SparkSession from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader +from pyspark.sql.types import Row, StringType from pyspark.sql.utils import install_exception_handler -from pyspark.sql.functions import UserDefinedFunction - -try: - import pandas - has_pandas = True -except Exception: - has_pandas = False __all__ = ["SQLContext", "HiveContext", "UDFRegistration"] -def _monkey_patch_RDD(sqlContext): - def toDF(self, schema=None, sampleRatio=None): - """ - Converts current :class:`RDD` into a :class:`DataFrame` - - This is a shorthand for ``sqlContext.createDataFrame(rdd, schema, sampleRatio)`` - - :param schema: a StructType or list of names of columns - :param samplingRatio: the sample ratio of rows used for inferring - :return: a DataFrame - - >>> rdd.toDF().collect() - [Row(name=u'Alice', age=1)] - """ - return sqlContext.createDataFrame(self, schema, sampleRatio) - - RDD.toDF = toDF - - class SQLContext(object): - """Main entry point for Spark SQL functionality. + """Wrapper around :class:`SparkSession`, the main entry point to Spark SQL functionality. A SQLContext can be used create :class:`DataFrame`, register :class:`DataFrame` as tables, execute SQL over tables, cache tables, and read parquet files. :param sparkContext: The :class:`SparkContext` backing this SQLContext. - :param sqlContext: An optional JVM Scala SQLContext. 
If set, we do not instantiate a new + :param sparkSession: The :class:`SparkSession` around which this SQLContext wraps. + :param jsqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new SQLContext in the JVM, instead we make all calls to this object. """ _instantiatedContext = None @ignore_unicode_prefix - def __init__(self, sparkContext, sqlContext=None): + def __init__(self, sparkContext, sparkSession=None, jsqlContext=None): """Creates a new SQLContext. >>> from datetime import datetime @@ -100,8 +67,13 @@ class SQLContext(object): self._sc = sparkContext self._jsc = self._sc._jsc self._jvm = self._sc._jvm - self._scala_SQLContext = sqlContext - _monkey_patch_RDD(self) + if sparkSession is None: + sparkSession = SparkSession(sparkContext) + if jsqlContext is None: + jsqlContext = sparkSession._jwrapped + self.sparkSession = sparkSession + self._jsqlContext = jsqlContext + _monkey_patch_RDD(self.sparkSession) install_exception_handler() if SQLContext._instantiatedContext is None: SQLContext._instantiatedContext = self @@ -113,9 +85,7 @@ class SQLContext(object): Subclasses can override this property to provide their own JVM Contexts. """ - if self._scala_SQLContext is None: - self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) - return self._scala_SQLContext + return self._jsqlContext @classmethod @since(1.6) @@ -127,7 +97,8 @@ class SQLContext(object): """ if cls._instantiatedContext is None: jsqlContext = sc._jvm.SQLContext.getOrCreate(sc._jsc.sc()) - cls(sc, jsqlContext) + sparkSession = SparkSession(sc, jsqlContext.sparkSession()) + cls(sc, sparkSession, jsqlContext) return cls._instantiatedContext @since(1.6) @@ -137,14 +108,13 @@ class SQLContext(object): registered temporary tables and UDFs, but shared SparkContext and table cache. """ - jsqlContext = self._ssql_ctx.newSession() - return self.__class__(self._sc, jsqlContext) + return self.__class__(self._sc, self.sparkSession.newSession()) @since(1.3) def setConf(self, key, value): """Sets the given Spark SQL configuration property. 
""" - self._ssql_ctx.setConf(key, value) + self.sparkSession.setConf(key, value) @ignore_unicode_prefix @since(1.3) @@ -163,10 +133,7 @@ class SQLContext(object): >>> sqlContext.getConf("spark.sql.shuffle.partitions", "10") u'50' """ - if defaultValue is not None: - return self._ssql_ctx.getConf(key, defaultValue) - else: - return self._ssql_ctx.getConf(key) + return self.sparkSession.getConf(key, defaultValue) @property @since("1.3.1") @@ -175,7 +142,7 @@ class SQLContext(object): :return: :class:`UDFRegistration` """ - return UDFRegistration(self) + return UDFRegistration(self.sparkSession) @since(1.4) def range(self, start, end=None, step=1, numPartitions=None): @@ -198,15 +165,7 @@ class SQLContext(object): >>> sqlContext.range(3).collect() [Row(id=0), Row(id=1), Row(id=2)] """ - if numPartitions is None: - numPartitions = self._sc.defaultParallelism - - if end is None: - jdf = self._ssql_ctx.range(0, int(start), int(step), int(numPartitions)) - else: - jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions)) - - return DataFrame(jdf, self) + return self.sparkSession.range(start, end, step, numPartitions) @ignore_unicode_prefix @since(1.2) @@ -236,27 +195,9 @@ class SQLContext(object): >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] """ - udf = UserDefinedFunction(f, returnType, name) - self._ssql_ctx.udf().registerPython(name, udf._judf) - - def _inferSchemaFromList(self, data): - """ - Infer schema from list of Row or tuple. - - :param data: list of Row or tuple - :return: StructType - """ - if not data: - raise ValueError("can not infer schema from empty dataset") - first = data[0] - if type(first) is dict: - warnings.warn("inferring schema from dict is deprecated," - "please use pyspark.sql.Row instead") - schema = reduce(_merge_type, map(_infer_schema, data)) - if _has_nulltype(schema): - raise ValueError("Some of types cannot be determined after inferring") - return schema + self.sparkSession.registerFunction(name, f, returnType) + # TODO(andrew): delete this once we refactor things to take in SparkSession def _inferSchema(self, rdd, samplingRatio=None): """ Infer schema from an RDD of Row or tuple. @@ -265,78 +206,7 @@ class SQLContext(object): :param samplingRatio: sampling ratio, or no sampling (default) :return: StructType """ - first = rdd.first() - if not first: - raise ValueError("The first row in RDD is empty, " - "can not infer schema") - if type(first) is dict: - warnings.warn("Using RDD of dict to inferSchema is deprecated. " - "Use pyspark.sql.Row instead") - - if samplingRatio is None: - schema = _infer_schema(first) - if _has_nulltype(schema): - for row in rdd.take(100)[1:]: - schema = _merge_type(schema, _infer_schema(row)) - if not _has_nulltype(schema): - break - else: - raise ValueError("Some of types cannot be determined by the " - "first 100 rows, please try again with sampling") - else: - if samplingRatio < 0.99: - rdd = rdd.sample(False, float(samplingRatio)) - schema = rdd.map(_infer_schema).reduce(_merge_type) - return schema - - def _createFromRDD(self, rdd, schema, samplingRatio): - """ - Create an RDD for DataFrame from an existing RDD, returns the RDD and schema. 
- """ - if schema is None or isinstance(schema, (list, tuple)): - struct = self._inferSchema(rdd, samplingRatio) - converter = _create_converter(struct) - rdd = rdd.map(converter) - if isinstance(schema, (list, tuple)): - for i, name in enumerate(schema): - struct.fields[i].name = name - struct.names[i] = name - schema = struct - - elif not isinstance(schema, StructType): - raise TypeError("schema should be StructType or list or None, but got: %s" % schema) - - # convert python objects to sql data - rdd = rdd.map(schema.toInternal) - return rdd, schema - - def _createFromLocal(self, data, schema): - """ - Create an RDD for DataFrame from an list or pandas.DataFrame, returns - the RDD and schema. - """ - # make sure data could consumed multiple times - if not isinstance(data, list): - data = list(data) - - if schema is None or isinstance(schema, (list, tuple)): - struct = self._inferSchemaFromList(data) - if isinstance(schema, (list, tuple)): - for i, name in enumerate(schema): - struct.fields[i].name = name - struct.names[i] = name - schema = struct - - elif isinstance(schema, StructType): - for row in data: - _verify_type(row, schema) - - else: - raise TypeError("schema should be StructType or list or None, but got: %s" % schema) - - # convert python objects to sql data - data = [schema.toInternal(row) for row in data] - return self._sc.parallelize(data), schema + return self.sparkSession._inferSchema(rdd, samplingRatio) @since(1.3) @ignore_unicode_prefix @@ -421,40 +291,7 @@ class SQLContext(object): ... Py4JJavaError: ... """ - if isinstance(data, DataFrame): - raise TypeError("data is already a DataFrame") - - if isinstance(schema, basestring): - schema = _parse_datatype_string(schema) - - if has_pandas and isinstance(data, pandas.DataFrame): - if schema is None: - schema = [str(x) for x in data.columns] - data = [r.tolist() for r in data.to_records(index=False)] - - if isinstance(schema, StructType): - def prepare(obj): - _verify_type(obj, schema) - return obj - elif isinstance(schema, DataType): - datatype = schema - - def prepare(obj): - _verify_type(obj, datatype) - return (obj, ) - schema = StructType().add("value", datatype) - else: - prepare = lambda obj: obj - - if isinstance(data, RDD): - rdd, schema = self._createFromRDD(data.map(prepare), schema, samplingRatio) - else: - rdd, schema = self._createFromLocal(map(prepare, data), schema) - jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) - jdf = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) - df = DataFrame(jdf, self) - df._schema = schema - return df + return self.sparkSession.createDataFrame(data, schema, samplingRatio) @since(1.3) def registerDataFrameAsTable(self, df, tableName): @@ -464,10 +301,7 @@ class SQLContext(object): >>> sqlContext.registerDataFrameAsTable(df, "table1") """ - if (df.__class__ is DataFrame): - self._ssql_ctx.registerDataFrameAsTable(df._jdf, tableName) - else: - raise ValueError("Can only register DataFrame as table") + self.sparkSession.registerDataFrameAsTable(df, tableName) @since(1.6) def dropTempTable(self, tableName): @@ -493,20 +327,7 @@ class SQLContext(object): :return: :class:`DataFrame` """ - if path is not None: - options["path"] = path - if source is None: - source = self.getConf("spark.sql.sources.default", - "org.apache.spark.sql.parquet") - if schema is None: - df = self._ssql_ctx.createExternalTable(tableName, source, options) - else: - if not isinstance(schema, StructType): - raise TypeError("schema should be StructType") - scala_datatype = 
self._ssql_ctx.parseDataType(schema.json()) - df = self._ssql_ctx.createExternalTable(tableName, source, scala_datatype, - options) - return DataFrame(df, self) + return self.sparkSession.createExternalTable(tableName, path, source, schema, **options) @ignore_unicode_prefix @since(1.0) @@ -520,7 +341,7 @@ class SQLContext(object): >>> df2.collect() [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] """ - return DataFrame(self._ssql_ctx.sql(sqlQuery), self) + return self.sparkSession.sql(sqlQuery) @since(1.0) def table(self, tableName): @@ -533,7 +354,7 @@ class SQLContext(object): >>> sorted(df.collect()) == sorted(df2.collect()) True """ - return DataFrame(self._ssql_ctx.table(tableName), self) + return self.sparkSession.table(tableName) @ignore_unicode_prefix @since(1.3) @@ -603,7 +424,7 @@ class SQLContext(object): return DataFrameReader(self) -# TODO(andrew): remove this too +# TODO(andrew): deprecate this class HiveContext(SQLContext): """A variant of Spark SQL that integrates with data stored in Hive. @@ -611,29 +432,28 @@ class HiveContext(SQLContext): It supports running both SQL and HiveQL commands. :param sparkContext: The SparkContext to wrap. - :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instantiate a new + :param jhiveContext: An optional JVM Scala HiveContext. If set, we do not instantiate a new :class:`HiveContext` in the JVM, instead we make all calls to this object. """ - def __init__(self, sparkContext, hiveContext=None): - SQLContext.__init__(self, sparkContext) - if hiveContext: - self._scala_HiveContext = hiveContext + def __init__(self, sparkContext, jhiveContext=None): + if jhiveContext is None: + sparkSession = SparkSession.withHiveSupport(sparkContext) + else: + sparkSession = SparkSession(sparkContext, jhiveContext.sparkSession()) + SQLContext.__init__(self, sparkContext, sparkSession, jhiveContext) + + @classmethod + def _createForTesting(cls, sparkContext): + """(Internal use only) Create a new HiveContext for testing. - @property - def _ssql_ctx(self): - try: - if not hasattr(self, '_scala_HiveContext'): - self._scala_HiveContext = self._get_hive_ctx() - return self._scala_HiveContext - except Py4JError as e: - print("You must build Spark with Hive. " - "Export 'SPARK_HIVE=true' and run " - "build/sbt assembly", file=sys.stderr) - raise - - def _get_hive_ctx(self): - return self._jvm.SparkSession.withHiveSupport(self._jsc.sc()).wrapped() + All test code that touches HiveContext *must* go through this method. Otherwise, + you may end up launching multiple derby instances and encounter with incredibly + confusing error messages. 
+ """ + jsc = sparkContext._jsc.sc() + jtestHive = sparkContext._jvm.org.apache.spark.sql.hive.test.TestHiveContext(jsc) + return cls(sparkContext, jtestHive) def refreshTable(self, tableName): """Invalidate and refresh all the cached the metadata of the given diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index e39cf1ae03..784609e4c5 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -743,7 +743,7 @@ def _test(): globs['os'] = os globs['sc'] = sc globs['sqlContext'] = SQLContext(sc) - globs['hiveContext'] = HiveContext(sc) + globs['hiveContext'] = HiveContext._createForTesting(sc) globs['df'] = globs['sqlContext'].read.parquet('python/test_support/sql/parquet_partitioned') globs['sdf'] =\ globs['sqlContext'].read.format('text').stream('python/test_support/sql/streaming') diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py new file mode 100644 index 0000000000..d3355f9da7 --- /dev/null +++ b/python/pyspark/sql/session.py @@ -0,0 +1,525 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function +import sys +import warnings +from functools import reduce + +if sys.version >= '3': + basestring = unicode = str +else: + from itertools import imap as map + +from pyspark import since +from pyspark.rdd import RDD, ignore_unicode_prefix +from pyspark.sql.dataframe import DataFrame +from pyspark.sql.functions import UserDefinedFunction +from pyspark.sql.readwriter import DataFrameReader +from pyspark.sql.types import Row, DataType, StringType, StructType, _verify_type, \ + _infer_schema, _has_nulltype, _merge_type, _create_converter, _parse_datatype_string +from pyspark.sql.utils import install_exception_handler + +__all__ = ["SparkSession"] + + +def _monkey_patch_RDD(sparkSession): + def toDF(self, schema=None, sampleRatio=None): + """ + Converts current :class:`RDD` into a :class:`DataFrame` + + This is a shorthand for ``spark.createDataFrame(rdd, schema, sampleRatio)`` + + :param schema: a StructType or list of names of columns + :param samplingRatio: the sample ratio of rows used for inferring + :return: a DataFrame + + >>> rdd.toDF().collect() + [Row(name=u'Alice', age=1)] + """ + return sparkSession.createDataFrame(self, schema, sampleRatio) + + RDD.toDF = toDF + + +# TODO(andrew): implement conf and catalog namespaces +class SparkSession(object): + """Main entry point for Spark SQL functionality. + + A SparkSession can be used create :class:`DataFrame`, register :class:`DataFrame` as + tables, execute SQL over tables, cache tables, and read parquet files. + + :param sparkContext: The :class:`SparkContext` backing this SparkSession. + :param jsparkSession: An optional JVM Scala SparkSession. 
If set, we do not instantiate a new + SparkSession in the JVM, instead we make all calls to this object. + """ + + _instantiatedContext = None + + @ignore_unicode_prefix + def __init__(self, sparkContext, jsparkSession=None): + """Creates a new SparkSession. + + >>> from datetime import datetime + >>> spark = SparkSession(sc) + >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1, + ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), + ... time=datetime(2014, 8, 1, 14, 1, 5))]) + >>> df = allTypes.toDF() + >>> df.registerTempTable("allTypes") + >>> spark.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' + ... 'from allTypes where b and i > 0').collect() + [Row((i + CAST(1 AS BIGINT))=2, (d + CAST(1 AS DOUBLE))=2.0, (NOT b)=False, list[1]=2, \ + dict[s]=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)] + >>> df.rdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect() + [(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] + """ + from pyspark.sql.context import SQLContext + self._sc = sparkContext + self._jsc = self._sc._jsc + self._jvm = self._sc._jvm + if jsparkSession is None: + jsparkSession = self._jvm.SparkSession(self._jsc.sc()) + self._jsparkSession = jsparkSession + self._jwrapped = self._jsparkSession.wrapped() + self._wrapped = SQLContext(self._sc, self, self._jwrapped) + _monkey_patch_RDD(self) + install_exception_handler() + if SparkSession._instantiatedContext is None: + SparkSession._instantiatedContext = self + + @classmethod + @since(2.0) + def withHiveSupport(cls, sparkContext): + """Returns a new SparkSession with a catalog backed by Hive + + :param sparkContext: The underlying :class:`SparkContext`. + """ + jsparkSession = sparkContext._jvm.SparkSession.withHiveSupport(sparkContext._jsc.sc()) + return cls(sparkContext, jsparkSession) + + @since(2.0) + def newSession(self): + """ + Returns a new SparkSession as new session, that has separate SQLConf, + registered temporary tables and UDFs, but shared SparkContext and + table cache. + """ + return self.__class__(self._sc, self._jsparkSession.newSession()) + + @since(2.0) + def setConf(self, key, value): + """ + Sets the given Spark SQL configuration property. + """ + self._jsparkSession.setConf(key, value) + + @ignore_unicode_prefix + @since(2.0) + def getConf(self, key, defaultValue=None): + """Returns the value of Spark SQL configuration property for the given key. + + If the key is not set and defaultValue is not None, return + defaultValue. If the key is not set and defaultValue is None, return + the system default value. + + >>> spark.getConf("spark.sql.shuffle.partitions") + u'200' + >>> spark.getConf("spark.sql.shuffle.partitions", "10") + u'10' + >>> spark.setConf("spark.sql.shuffle.partitions", "50") + >>> spark.getConf("spark.sql.shuffle.partitions", "10") + u'50' + """ + if defaultValue is not None: + return self._jsparkSession.getConf(key, defaultValue) + else: + return self._jsparkSession.getConf(key) + + @property + @since(2.0) + def udf(self): + """Returns a :class:`UDFRegistration` for UDF registration. + + :return: :class:`UDFRegistration` + """ + return UDFRegistration(self) + + @since(2.0) + def range(self, start, end=None, step=1, numPartitions=None): + """ + Create a :class:`DataFrame` with single LongType column named `id`, + containing elements in a range from `start` to `end` (exclusive) with + step value `step`. 
+ + :param start: the start value + :param end: the end value (exclusive) + :param step: the incremental step (default: 1) + :param numPartitions: the number of partitions of the DataFrame + :return: :class:`DataFrame` + + >>> spark.range(1, 7, 2).collect() + [Row(id=1), Row(id=3), Row(id=5)] + + If only one argument is specified, it will be used as the end value. + + >>> spark.range(3).collect() + [Row(id=0), Row(id=1), Row(id=2)] + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + + if end is None: + jdf = self._jsparkSession.range(0, int(start), int(step), int(numPartitions)) + else: + jdf = self._jsparkSession.range(int(start), int(end), int(step), int(numPartitions)) + + return DataFrame(jdf, self._wrapped) + + @ignore_unicode_prefix + @since(2.0) + def registerFunction(self, name, f, returnType=StringType()): + """Registers a python function (including lambda function) as a UDF + so it can be used in SQL statements. + + In addition to a name and the function itself, the return type can be optionally specified. + When the return type is not given it default to a string and conversion will automatically + be done. For any other return type, the produced object must match the specified type. + + :param name: name of the UDF + :param f: python function + :param returnType: a :class:`DataType` object + + >>> spark.registerFunction("stringLengthString", lambda x: len(x)) + >>> spark.sql("SELECT stringLengthString('test')").collect() + [Row(stringLengthString(test)=u'4')] + + >>> from pyspark.sql.types import IntegerType + >>> spark.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) + >>> spark.sql("SELECT stringLengthInt('test')").collect() + [Row(stringLengthInt(test)=4)] + + >>> from pyspark.sql.types import IntegerType + >>> spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> spark.sql("SELECT stringLengthInt('test')").collect() + [Row(stringLengthInt(test)=4)] + """ + udf = UserDefinedFunction(f, returnType, name) + self._jsparkSession.udf().registerPython(name, udf._judf) + + def _inferSchemaFromList(self, data): + """ + Infer schema from list of Row or tuple. + + :param data: list of Row or tuple + :return: StructType + """ + if not data: + raise ValueError("can not infer schema from empty dataset") + first = data[0] + if type(first) is dict: + warnings.warn("inferring schema from dict is deprecated," + "please use pyspark.sql.Row instead") + schema = reduce(_merge_type, map(_infer_schema, data)) + if _has_nulltype(schema): + raise ValueError("Some of types cannot be determined after inferring") + return schema + + def _inferSchema(self, rdd, samplingRatio=None): + """ + Infer schema from an RDD of Row or tuple. + + :param rdd: an RDD of Row or tuple + :param samplingRatio: sampling ratio, or no sampling (default) + :return: StructType + """ + first = rdd.first() + if not first: + raise ValueError("The first row in RDD is empty, " + "can not infer schema") + if type(first) is dict: + warnings.warn("Using RDD of dict to inferSchema is deprecated. 
" + "Use pyspark.sql.Row instead") + + if samplingRatio is None: + schema = _infer_schema(first) + if _has_nulltype(schema): + for row in rdd.take(100)[1:]: + schema = _merge_type(schema, _infer_schema(row)) + if not _has_nulltype(schema): + break + else: + raise ValueError("Some of types cannot be determined by the " + "first 100 rows, please try again with sampling") + else: + if samplingRatio < 0.99: + rdd = rdd.sample(False, float(samplingRatio)) + schema = rdd.map(_infer_schema).reduce(_merge_type) + return schema + + def _createFromRDD(self, rdd, schema, samplingRatio): + """ + Create an RDD for DataFrame from an existing RDD, returns the RDD and schema. + """ + if schema is None or isinstance(schema, (list, tuple)): + struct = self._inferSchema(rdd, samplingRatio) + converter = _create_converter(struct) + rdd = rdd.map(converter) + if isinstance(schema, (list, tuple)): + for i, name in enumerate(schema): + struct.fields[i].name = name + struct.names[i] = name + schema = struct + + elif not isinstance(schema, StructType): + raise TypeError("schema should be StructType or list or None, but got: %s" % schema) + + # convert python objects to sql data + rdd = rdd.map(schema.toInternal) + return rdd, schema + + def _createFromLocal(self, data, schema): + """ + Create an RDD for DataFrame from an list or pandas.DataFrame, returns + the RDD and schema. + """ + # make sure data could consumed multiple times + if not isinstance(data, list): + data = list(data) + + if schema is None or isinstance(schema, (list, tuple)): + struct = self._inferSchemaFromList(data) + if isinstance(schema, (list, tuple)): + for i, name in enumerate(schema): + struct.fields[i].name = name + struct.names[i] = name + schema = struct + + elif isinstance(schema, StructType): + for row in data: + _verify_type(row, schema) + + else: + raise TypeError("schema should be StructType or list or None, but got: %s" % schema) + + # convert python objects to sql data + data = [schema.toInternal(row) for row in data] + return self._sc.parallelize(data), schema + + @since(2.0) + @ignore_unicode_prefix + def createDataFrame(self, data, schema=None, samplingRatio=None): + """ + Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`. + + When ``schema`` is a list of column names, the type of each column + will be inferred from ``data``. + + When ``schema`` is ``None``, it will try to infer the schema (column names and types) + from ``data``, which should be an RDD of :class:`Row`, + or :class:`namedtuple`, or :class:`dict`. + + When ``schema`` is :class:`DataType` or datatype string, it must match the real data, or + exception will be thrown at runtime. If the given schema is not StructType, it will be + wrapped into a StructType as its only field, and the field name will be "value", each record + will also be wrapped into a tuple, which can be converted to row later. + + If schema inference is needed, ``samplingRatio`` is used to determined the ratio of + rows used for schema inference. The first row will be used if ``samplingRatio`` is ``None``. + + :param data: an RDD of any kind of SQL data representation(e.g. row, tuple, int, boolean, + etc.), or :class:`list`, or :class:`pandas.DataFrame`. + :param schema: a :class:`DataType` or a datatype string or a list of column names, default + is None. The data type string format equals to `DataType.simpleString`, except that + top level struct type can omit the `struct<>` and atomic types use `typeName()` as + their format, e.g. 
use `byte` instead of `tinyint` for ByteType. We can also use `int` + as a short name for IntegerType. + :param samplingRatio: the sample ratio of rows used for inferring + :return: :class:`DataFrame` + + .. versionchanged:: 2.0 + The schema parameter can be a DataType or a datatype string after 2.0. If it's not a + StructType, it will be wrapped into a StructType and each record will also be wrapped + into a tuple. + + >>> l = [('Alice', 1)] + >>> spark.createDataFrame(l).collect() + [Row(_1=u'Alice', _2=1)] + >>> spark.createDataFrame(l, ['name', 'age']).collect() + [Row(name=u'Alice', age=1)] + + >>> d = [{'name': 'Alice', 'age': 1}] + >>> spark.createDataFrame(d).collect() + [Row(age=1, name=u'Alice')] + + >>> rdd = sc.parallelize(l) + >>> spark.createDataFrame(rdd).collect() + [Row(_1=u'Alice', _2=1)] + >>> df = spark.createDataFrame(rdd, ['name', 'age']) + >>> df.collect() + [Row(name=u'Alice', age=1)] + + >>> from pyspark.sql import Row + >>> Person = Row('name', 'age') + >>> person = rdd.map(lambda r: Person(*r)) + >>> df2 = spark.createDataFrame(person) + >>> df2.collect() + [Row(name=u'Alice', age=1)] + + >>> from pyspark.sql.types import * + >>> schema = StructType([ + ... StructField("name", StringType(), True), + ... StructField("age", IntegerType(), True)]) + >>> df3 = spark.createDataFrame(rdd, schema) + >>> df3.collect() + [Row(name=u'Alice', age=1)] + + >>> spark.createDataFrame(df.toPandas()).collect() # doctest: +SKIP + [Row(name=u'Alice', age=1)] + >>> spark.createDataFrame(pandas.DataFrame([[1, 2]])).collect() # doctest: +SKIP + [Row(0=1, 1=2)] + + >>> spark.createDataFrame(rdd, "a: string, b: int").collect() + [Row(a=u'Alice', b=1)] + >>> rdd = rdd.map(lambda row: row[1]) + >>> spark.createDataFrame(rdd, "int").collect() + [Row(value=1)] + >>> spark.createDataFrame(rdd, "boolean").collect() # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + Py4JJavaError: ... + """ + if isinstance(data, DataFrame): + raise TypeError("data is already a DataFrame") + + if isinstance(schema, basestring): + schema = _parse_datatype_string(schema) + + try: + import pandas + has_pandas = True + except Exception: + has_pandas = False + if has_pandas and isinstance(data, pandas.DataFrame): + if schema is None: + schema = [str(x) for x in data.columns] + data = [r.tolist() for r in data.to_records(index=False)] + + if isinstance(schema, StructType): + def prepare(obj): + _verify_type(obj, schema) + return obj + elif isinstance(schema, DataType): + datatype = schema + + def prepare(obj): + _verify_type(obj, datatype) + return (obj, ) + schema = StructType().add("value", datatype) + else: + prepare = lambda obj: obj + + if isinstance(data, RDD): + rdd, schema = self._createFromRDD(data.map(prepare), schema, samplingRatio) + else: + rdd, schema = self._createFromLocal(map(prepare, data), schema) + jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) + jdf = self._jsparkSession.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) + df = DataFrame(jdf, self._wrapped) + df._schema = schema + return df + + @since(2.0) + def registerDataFrameAsTable(self, df, tableName): + """Registers the given :class:`DataFrame` as a temporary table in the catalog. + + Temporary tables exist only during the lifetime of this instance of :class:`SparkSession`. 
+ + >>> spark.registerDataFrameAsTable(df, "table1") + """ + if (df.__class__ is DataFrame): + self._jsparkSession.registerDataFrameAsTable(df._jdf, tableName) + else: + raise ValueError("Can only register DataFrame as table") + + @since(2.0) + def createExternalTable(self, tableName, path=None, source=None, schema=None, **options): + """Creates an external table based on the dataset in a data source. + + It returns the DataFrame associated with the external table. + + The data source is specified by the ``source`` and a set of ``options``. + If ``source`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. + + Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and + created external table. + + :return: :class:`DataFrame` + """ + if path is not None: + options["path"] = path + if source is None: + source = self.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + if schema is None: + df = self._jsparkSession.catalog().createExternalTable(tableName, source, options) + else: + if not isinstance(schema, StructType): + raise TypeError("schema should be StructType") + scala_datatype = self._jsparkSession.parseDataType(schema.json()) + df = self._jsparkSession.catalog().createExternalTable( + tableName, source, scala_datatype, options) + return DataFrame(df, self._wrapped) + + @ignore_unicode_prefix + @since(2.0) + def sql(self, sqlQuery): + """Returns a :class:`DataFrame` representing the result of the given query. + + :return: :class:`DataFrame` + + >>> spark.registerDataFrameAsTable(df, "table1") + >>> df2 = spark.sql("SELECT field1 AS f1, field2 as f2 from table1") + >>> df2.collect() + [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] + """ + return DataFrame(self._jsparkSession.sql(sqlQuery), self._wrapped) + + @since(2.0) + def table(self, tableName): + """Returns the specified table as a :class:`DataFrame`. + + :return: :class:`DataFrame` + + >>> spark.registerDataFrameAsTable(df, "table1") + >>> df2 = spark.table("table1") + >>> sorted(df.collect()) == sorted(df2.collect()) + True + """ + return DataFrame(self._jsparkSession.table(tableName), self._wrapped) + + @property + @since(2.0) + def read(self): + """ + Returns a :class:`DataFrameReader` that can be used to read data + in as a :class:`DataFrame`. + + :return: :class:`DataFrameReader` + """ + return DataFrameReader(self._wrapped) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 42e283073f..99a12d639a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1369,9 +1369,7 @@ class HiveContextSQLTests(ReusedPySparkTestCase): cls.tearDownClass() raise unittest.SkipTest("Hive is not available") os.unlink(cls.tempdir.name) - _scala_HiveContext =\ - cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc()) - cls.sqlCtx = HiveContext(cls.sc, _scala_HiveContext) + cls.sqlCtx = HiveContext._createForTesting(cls.sc) cls.testData = [Row(key=i, value=str(i)) for i in range(100)] cls.df = cls.sc.parallelize(cls.testData).toDF() -- cgit v1.2.3
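
For reference, a minimal driver-side sketch of the new API in `python/pyspark/sql/session.py`. At this stage the session is constructed directly from a `SparkContext` (no `conf`/`catalog` accessors yet, as noted above); the master and app name below are illustrative assumptions, not part of the patch.

```python
# Standalone sketch (assumed master/app name); mirrors the session API added here.
from pyspark import SparkConf, SparkContext
from pyspark.sql import Row, SparkSession

sc = SparkContext(conf=SparkConf().setMaster("local[2]").setAppName("session-sketch"))
spark = SparkSession(sc)                     # plain session, no Hive catalog

# range() yields a single LongType column named `id`
spark.range(1, 10, 2).show()

# createDataFrame() infers the schema from Row objects
df = spark.createDataFrame([Row(name="Alice", age=1), Row(name="Bob", age=2)])
spark.registerDataFrameAsTable(df, "people")
spark.sql("SELECT name, age + 1 AS next_age FROM people").show()

sc.stop()
```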
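
The updated `shell.py` picks a Hive-backed session only if `HiveConf` can be instantiated, and otherwise falls back to a plain session. The same probe can be sketched in a standalone script, assuming `sc` is an already-running `SparkContext`:

```python
import py4j
from pyspark.sql import SparkSession

try:
    # Raises through py4j if Hive classes are not on the classpath
    sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
    spark = SparkSession.withHiveSupport(sc)
except (py4j.protocol.Py4JError, TypeError):
    spark = SparkSession(sc)
```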
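
Since `SQLContext` is now a thin wrapper that delegates to a `SparkSession` (and the shell exposes `spark._wrapped` as `sqlContext`/`sqlCtx`), existing call sites keep working. A compatibility sketch, assuming the shell-provided `sc` and `spark`:

```python
from pyspark.sql import SQLContext

sqlContext = spark._wrapped     # the object the shell still binds to `sqlContext`
legacy = SQLContext(sc)         # old-style constructor; wraps a SparkSession internally

# Wrapper methods simply forward to the session.
assert sqlContext.range(3).collect() == spark.range(3).collect()
sqlContext.registerDataFrameAsTable(spark.range(3), "ids")
sqlContext.sql("SELECT count(*) FROM ids").show()
```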
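
Constructing a `SparkSession` also re-applies the `RDD.toDF` monkey patch (`_monkey_patch_RDD`), so the familiar shorthand still routes through `createDataFrame`. A sketch assuming the shell-provided `sc`:

```python
from pyspark.sql import Row

rdd = sc.parallelize([Row(name="Alice", age=1), Row(name="Bob", age=2)])
rdd.toDF().show()                       # schema inferred from the Rows

# A tuple RDD can take a list of column names as its schema
sc.parallelize([("Carol", 3)]).toDF(["name", "age"]).show()
```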
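
Per the `createDataFrame` docstring, the `schema` argument may now be a `DataType` or a datatype string as well as a `StructType` or a list of names; atomic types get wrapped into a single `value` column. A sketch assuming the shell-provided `sc` and `spark`:

```python
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

rdd = sc.parallelize([("Alice", 1), ("Bob", 2)])

# Explicit StructType
schema = StructType([StructField("name", StringType(), True),
                     StructField("age", IntegerType(), True)])
spark.createDataFrame(rdd, schema).show()

# Equivalent datatype string; the top-level struct<> may be omitted
spark.createDataFrame(rdd, "name: string, age: int").show()

# Atomic type: each record is wrapped into a one-field row named `value`
spark.createDataFrame(sc.parallelize([1, 2, 3]), "int").show()
```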
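
Configuration, session isolation and UDF registration also live on the session. A sketch assuming the shell-provided `spark`; the behaviour of `getConf` on the fresh session is inferred from the docstrings (a separate `SQLConf` per session), so treat the commented values as indicative:

```python
from pyspark.sql.types import IntegerType

spark.setConf("spark.sql.shuffle.partitions", "8")
spark.getConf("spark.sql.shuffle.partitions")        # u'8'

other = spark.newSession()   # separate SQLConf, temp tables and UDFs; shared SparkContext
other.getConf("spark.sql.shuffle.partitions", "10")  # likely u'10': the key is unset here

spark.registerFunction("strLen", lambda s: len(s), IntegerType())
spark.sql("SELECT strLen('SparkSession')").collect()
```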
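
Finally, test code that needs a `HiveContext` is expected to go through `_createForTesting` (as `tests.py` and `readwriter.py` now do) to avoid launching multiple Derby metastores. A sketch assuming an existing `SparkContext` `sc` with the Hive test classes on the classpath:

```python
from pyspark.sql import HiveContext

hive_ctx = HiveContext._createForTesting(sc)   # wraps a JVM TestHiveContext
hive_ctx.sql("SHOW TABLES").show()
```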