| author | Andrew Or <andrew@databricks.com> | 2016-04-28 10:55:48 -0700 |
|---|---|---|
| committer | Reynold Xin <rxin@databricks.com> | 2016-04-28 10:55:48 -0700 |
| commit | 89addd40abdacd65cc03ac8aa5f9cf3dd4a4c19b (patch) | |
| tree | 5ecd3d9a736333c7951de6159eefef86129e3744 /python/pyspark/sql/session.py | |
| parent | 5743352a28fffbfbaca2201208ce7a1d7893f813 (diff) | |
[SPARK-14945][PYTHON] SparkSession Python API
## What changes were proposed in this pull request?
```
Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/ '_/
   /__ / .__/\_,_/_/ /_/\_\   version 2.0.0-SNAPSHOT
      /_/
Using Python version 2.7.5 (default, Mar 9 2014 22:15:05)
SparkSession available as 'spark'.
>>> spark
<pyspark.sql.session.SparkSession object at 0x101f3bfd0>
>>> spark.sql("SHOW TABLES").show()
...
+---------+-----------+
|tableName|isTemporary|
+---------+-----------+
| src| false|
+---------+-----------+
>>> spark.range(1, 10, 2).show()
+---+
| id|
+---+
| 1|
| 3|
| 5|
| 7|
| 9|
+---+
```
**Note**: This API is NOT complete in its current state. In particular, I have left out the `conf` and `catalog` APIs, which were added to the Scala API more recently; they will be added to the Python API before the 2.0 release.
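For reference, here is a minimal sketch (not part of this patch) of how the new Python `SparkSession` might be driven from a standalone script rather than the shell. It assumes a local `SparkContext`; the master URL, app name, and sample rows are illustrative only.

```
# A hedged usage sketch: exercises the new Python SparkSession from a script.
from pyspark import SparkContext
from pyspark.sql import Row
from pyspark.sql.session import SparkSession

sc = SparkContext("local[2]", "session-demo")  # illustrative local context
spark = SparkSession(sc)                       # constructor added in this patch

# Create a DataFrame with an inferred schema and query it through SQL.
df = spark.createDataFrame([Row(name="Alice", age=1), Row(name="Bob", age=2)])
spark.registerDataFrameAsTable(df, "people")
spark.sql("SELECT name FROM people WHERE age > 1").show()

# range() yields a DataFrame with a single LongType column named `id`.
spark.range(1, 10, 2).show()

sc.stop()
```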
## How was this patch tested?
Python tests.
Author: Andrew Or <andrew@databricks.com>
Closes #12746 from andrewor14/python-spark-session.
Diffstat (limited to 'python/pyspark/sql/session.py')
-rw-r--r-- | python/pyspark/sql/session.py | 525 |
1 file changed, 525 insertions, 0 deletions
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
new file mode 100644
index 0000000000..d3355f9da7
--- /dev/null
+++ b/python/pyspark/sql/session.py
@@ -0,0 +1,525 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+import sys
+import warnings
+from functools import reduce
+
+if sys.version >= '3':
+    basestring = unicode = str
+else:
+    from itertools import imap as map
+
+from pyspark import since
+from pyspark.rdd import RDD, ignore_unicode_prefix
+from pyspark.sql.dataframe import DataFrame
+from pyspark.sql.functions import UserDefinedFunction
+from pyspark.sql.readwriter import DataFrameReader
+from pyspark.sql.types import Row, DataType, StringType, StructType, _verify_type, \
+    _infer_schema, _has_nulltype, _merge_type, _create_converter, _parse_datatype_string
+from pyspark.sql.utils import install_exception_handler
+
+__all__ = ["SparkSession"]
+
+
+def _monkey_patch_RDD(sparkSession):
+    def toDF(self, schema=None, sampleRatio=None):
+        """
+        Converts current :class:`RDD` into a :class:`DataFrame`
+
+        This is a shorthand for ``spark.createDataFrame(rdd, schema, sampleRatio)``
+
+        :param schema: a StructType or list of names of columns
+        :param samplingRatio: the sample ratio of rows used for inferring
+        :return: a DataFrame
+
+        >>> rdd.toDF().collect()
+        [Row(name=u'Alice', age=1)]
+        """
+        return sparkSession.createDataFrame(self, schema, sampleRatio)
+
+    RDD.toDF = toDF
+
+
+# TODO(andrew): implement conf and catalog namespaces
+class SparkSession(object):
+    """Main entry point for Spark SQL functionality.
+
+    A SparkSession can be used create :class:`DataFrame`, register :class:`DataFrame` as
+    tables, execute SQL over tables, cache tables, and read parquet files.
+
+    :param sparkContext: The :class:`SparkContext` backing this SparkSession.
+    :param jsparkSession: An optional JVM Scala SparkSession. If set, we do not instantiate a new
+        SparkSession in the JVM, instead we make all calls to this object.
+    """
+
+    _instantiatedContext = None
+
+    @ignore_unicode_prefix
+    def __init__(self, sparkContext, jsparkSession=None):
+        """Creates a new SparkSession.
+
+        >>> from datetime import datetime
+        >>> spark = SparkSession(sc)
+        >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1,
+        ...     b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1),
+        ...     time=datetime(2014, 8, 1, 14, 1, 5))])
+        >>> df = allTypes.toDF()
+        >>> df.registerTempTable("allTypes")
+        >>> spark.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
+        ...     'from allTypes where b and i > 0').collect()
+        [Row((i + CAST(1 AS BIGINT))=2, (d + CAST(1 AS DOUBLE))=2.0, (NOT b)=False, list[1]=2, \
+            dict[s]=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
+        >>> df.rdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect()
+        [(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
+        """
+        from pyspark.sql.context import SQLContext
+        self._sc = sparkContext
+        self._jsc = self._sc._jsc
+        self._jvm = self._sc._jvm
+        if jsparkSession is None:
+            jsparkSession = self._jvm.SparkSession(self._jsc.sc())
+        self._jsparkSession = jsparkSession
+        self._jwrapped = self._jsparkSession.wrapped()
+        self._wrapped = SQLContext(self._sc, self, self._jwrapped)
+        _monkey_patch_RDD(self)
+        install_exception_handler()
+        if SparkSession._instantiatedContext is None:
+            SparkSession._instantiatedContext = self
+
+    @classmethod
+    @since(2.0)
+    def withHiveSupport(cls, sparkContext):
+        """Returns a new SparkSession with a catalog backed by Hive
+
+        :param sparkContext: The underlying :class:`SparkContext`.
+        """
+        jsparkSession = sparkContext._jvm.SparkSession.withHiveSupport(sparkContext._jsc.sc())
+        return cls(sparkContext, jsparkSession)
+
+    @since(2.0)
+    def newSession(self):
+        """
+        Returns a new SparkSession as new session, that has separate SQLConf,
+        registered temporary tables and UDFs, but shared SparkContext and
+        table cache.
+        """
+        return self.__class__(self._sc, self._jsparkSession.newSession())
+
+    @since(2.0)
+    def setConf(self, key, value):
+        """
+        Sets the given Spark SQL configuration property.
+        """
+        self._jsparkSession.setConf(key, value)
+
+    @ignore_unicode_prefix
+    @since(2.0)
+    def getConf(self, key, defaultValue=None):
+        """Returns the value of Spark SQL configuration property for the given key.
+
+        If the key is not set and defaultValue is not None, return
+        defaultValue. If the key is not set and defaultValue is None, return
+        the system default value.
+
+        >>> spark.getConf("spark.sql.shuffle.partitions")
+        u'200'
+        >>> spark.getConf("spark.sql.shuffle.partitions", "10")
+        u'10'
+        >>> spark.setConf("spark.sql.shuffle.partitions", "50")
+        >>> spark.getConf("spark.sql.shuffle.partitions", "10")
+        u'50'
+        """
+        if defaultValue is not None:
+            return self._jsparkSession.getConf(key, defaultValue)
+        else:
+            return self._jsparkSession.getConf(key)
+
+    @property
+    @since(2.0)
+    def udf(self):
+        """Returns a :class:`UDFRegistration` for UDF registration.
+
+        :return: :class:`UDFRegistration`
+        """
+        return UDFRegistration(self)
+
+    @since(2.0)
+    def range(self, start, end=None, step=1, numPartitions=None):
+        """
+        Create a :class:`DataFrame` with single LongType column named `id`,
+        containing elements in a range from `start` to `end` (exclusive) with
+        step value `step`.
+
+        :param start: the start value
+        :param end: the end value (exclusive)
+        :param step: the incremental step (default: 1)
+        :param numPartitions: the number of partitions of the DataFrame
+        :return: :class:`DataFrame`
+
+        >>> spark.range(1, 7, 2).collect()
+        [Row(id=1), Row(id=3), Row(id=5)]
+
+        If only one argument is specified, it will be used as the end value.
+
+        >>> spark.range(3).collect()
+        [Row(id=0), Row(id=1), Row(id=2)]
+        """
+        if numPartitions is None:
+            numPartitions = self._sc.defaultParallelism
+
+        if end is None:
+            jdf = self._jsparkSession.range(0, int(start), int(step), int(numPartitions))
+        else:
+            jdf = self._jsparkSession.range(int(start), int(end), int(step), int(numPartitions))
+
+        return DataFrame(jdf, self._wrapped)
+
+    @ignore_unicode_prefix
+    @since(2.0)
+    def registerFunction(self, name, f, returnType=StringType()):
+        """Registers a python function (including lambda function) as a UDF
+        so it can be used in SQL statements.
+
+        In addition to a name and the function itself, the return type can be optionally specified.
+        When the return type is not given it default to a string and conversion will automatically
+        be done. For any other return type, the produced object must match the specified type.
+
+        :param name: name of the UDF
+        :param f: python function
+        :param returnType: a :class:`DataType` object
+
+        >>> spark.registerFunction("stringLengthString", lambda x: len(x))
+        >>> spark.sql("SELECT stringLengthString('test')").collect()
+        [Row(stringLengthString(test)=u'4')]
+
+        >>> from pyspark.sql.types import IntegerType
+        >>> spark.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
+        >>> spark.sql("SELECT stringLengthInt('test')").collect()
+        [Row(stringLengthInt(test)=4)]
+
+        >>> from pyspark.sql.types import IntegerType
+        >>> spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
+        >>> spark.sql("SELECT stringLengthInt('test')").collect()
+        [Row(stringLengthInt(test)=4)]
+        """
+        udf = UserDefinedFunction(f, returnType, name)
+        self._jsparkSession.udf().registerPython(name, udf._judf)
+
+    def _inferSchemaFromList(self, data):
+        """
+        Infer schema from list of Row or tuple.
+
+        :param data: list of Row or tuple
+        :return: StructType
+        """
+        if not data:
+            raise ValueError("can not infer schema from empty dataset")
+        first = data[0]
+        if type(first) is dict:
+            warnings.warn("inferring schema from dict is deprecated,"
+                          "please use pyspark.sql.Row instead")
+        schema = reduce(_merge_type, map(_infer_schema, data))
+        if _has_nulltype(schema):
+            raise ValueError("Some of types cannot be determined after inferring")
+        return schema
+
+    def _inferSchema(self, rdd, samplingRatio=None):
+        """
+        Infer schema from an RDD of Row or tuple.
+
+        :param rdd: an RDD of Row or tuple
+        :param samplingRatio: sampling ratio, or no sampling (default)
+        :return: StructType
+        """
+        first = rdd.first()
+        if not first:
+            raise ValueError("The first row in RDD is empty, "
+                             "can not infer schema")
+        if type(first) is dict:
+            warnings.warn("Using RDD of dict to inferSchema is deprecated. "
+                          "Use pyspark.sql.Row instead")
+
+        if samplingRatio is None:
+            schema = _infer_schema(first)
+            if _has_nulltype(schema):
+                for row in rdd.take(100)[1:]:
+                    schema = _merge_type(schema, _infer_schema(row))
+                    if not _has_nulltype(schema):
+                        break
+                else:
+                    raise ValueError("Some of types cannot be determined by the "
+                                     "first 100 rows, please try again with sampling")
+        else:
+            if samplingRatio < 0.99:
+                rdd = rdd.sample(False, float(samplingRatio))
+            schema = rdd.map(_infer_schema).reduce(_merge_type)
+        return schema
+
+    def _createFromRDD(self, rdd, schema, samplingRatio):
+        """
+        Create an RDD for DataFrame from an existing RDD, returns the RDD and schema.
+        """
+        if schema is None or isinstance(schema, (list, tuple)):
+            struct = self._inferSchema(rdd, samplingRatio)
+            converter = _create_converter(struct)
+            rdd = rdd.map(converter)
+            if isinstance(schema, (list, tuple)):
+                for i, name in enumerate(schema):
+                    struct.fields[i].name = name
+                    struct.names[i] = name
+            schema = struct
+
+        elif not isinstance(schema, StructType):
+            raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
+
+        # convert python objects to sql data
+        rdd = rdd.map(schema.toInternal)
+        return rdd, schema
+
+    def _createFromLocal(self, data, schema):
+        """
+        Create an RDD for DataFrame from an list or pandas.DataFrame, returns
+        the RDD and schema.
+        """
+        # make sure data could consumed multiple times
+        if not isinstance(data, list):
+            data = list(data)
+
+        if schema is None or isinstance(schema, (list, tuple)):
+            struct = self._inferSchemaFromList(data)
+            if isinstance(schema, (list, tuple)):
+                for i, name in enumerate(schema):
+                    struct.fields[i].name = name
+                    struct.names[i] = name
+            schema = struct
+
+        elif isinstance(schema, StructType):
+            for row in data:
+                _verify_type(row, schema)
+
+        else:
+            raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
+
+        # convert python objects to sql data
+        data = [schema.toInternal(row) for row in data]
+        return self._sc.parallelize(data), schema
+
+    @since(2.0)
+    @ignore_unicode_prefix
+    def createDataFrame(self, data, schema=None, samplingRatio=None):
+        """
+        Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.
+
+        When ``schema`` is a list of column names, the type of each column
+        will be inferred from ``data``.
+
+        When ``schema`` is ``None``, it will try to infer the schema (column names and types)
+        from ``data``, which should be an RDD of :class:`Row`,
+        or :class:`namedtuple`, or :class:`dict`.
+
+        When ``schema`` is :class:`DataType` or datatype string, it must match the real data, or
+        exception will be thrown at runtime. If the given schema is not StructType, it will be
+        wrapped into a StructType as its only field, and the field name will be "value", each record
+        will also be wrapped into a tuple, which can be converted to row later.
+
+        If schema inference is needed, ``samplingRatio`` is used to determined the ratio of
+        rows used for schema inference. The first row will be used if ``samplingRatio`` is ``None``.
+
+        :param data: an RDD of any kind of SQL data representation(e.g. row, tuple, int, boolean,
+            etc.), or :class:`list`, or :class:`pandas.DataFrame`.
+        :param schema: a :class:`DataType` or a datatype string or a list of column names, default
+            is None. The data type string format equals to `DataType.simpleString`, except that
+            top level struct type can omit the `struct<>` and atomic types use `typeName()` as
+            their format, e.g. use `byte` instead of `tinyint` for ByteType. We can also use `int`
+            as a short name for IntegerType.
+        :param samplingRatio: the sample ratio of rows used for inferring
+        :return: :class:`DataFrame`
+
+        .. versionchanged:: 2.0
+           The schema parameter can be a DataType or a datatype string after 2.0. If it's not a
+           StructType, it will be wrapped into a StructType and each record will also be wrapped
+           into a tuple.
+
+        >>> l = [('Alice', 1)]
+        >>> spark.createDataFrame(l).collect()
+        [Row(_1=u'Alice', _2=1)]
+        >>> spark.createDataFrame(l, ['name', 'age']).collect()
+        [Row(name=u'Alice', age=1)]
+
+        >>> d = [{'name': 'Alice', 'age': 1}]
+        >>> spark.createDataFrame(d).collect()
+        [Row(age=1, name=u'Alice')]
+
+        >>> rdd = sc.parallelize(l)
+        >>> spark.createDataFrame(rdd).collect()
+        [Row(_1=u'Alice', _2=1)]
+        >>> df = spark.createDataFrame(rdd, ['name', 'age'])
+        >>> df.collect()
+        [Row(name=u'Alice', age=1)]
+
+        >>> from pyspark.sql import Row
+        >>> Person = Row('name', 'age')
+        >>> person = rdd.map(lambda r: Person(*r))
+        >>> df2 = spark.createDataFrame(person)
+        >>> df2.collect()
+        [Row(name=u'Alice', age=1)]
+
+        >>> from pyspark.sql.types import *
+        >>> schema = StructType([
+        ...    StructField("name", StringType(), True),
+        ...    StructField("age", IntegerType(), True)])
+        >>> df3 = spark.createDataFrame(rdd, schema)
+        >>> df3.collect()
+        [Row(name=u'Alice', age=1)]
+
+        >>> spark.createDataFrame(df.toPandas()).collect()  # doctest: +SKIP
+        [Row(name=u'Alice', age=1)]
+        >>> spark.createDataFrame(pandas.DataFrame([[1, 2]])).collect()  # doctest: +SKIP
+        [Row(0=1, 1=2)]
+
+        >>> spark.createDataFrame(rdd, "a: string, b: int").collect()
+        [Row(a=u'Alice', b=1)]
+        >>> rdd = rdd.map(lambda row: row[1])
+        >>> spark.createDataFrame(rdd, "int").collect()
+        [Row(value=1)]
+        >>> spark.createDataFrame(rdd, "boolean").collect() # doctest: +IGNORE_EXCEPTION_DETAIL
+        Traceback (most recent call last):
+            ...
+        Py4JJavaError: ...
+        """
+        if isinstance(data, DataFrame):
+            raise TypeError("data is already a DataFrame")
+
+        if isinstance(schema, basestring):
+            schema = _parse_datatype_string(schema)
+
+        try:
+            import pandas
+            has_pandas = True
+        except Exception:
+            has_pandas = False
+        if has_pandas and isinstance(data, pandas.DataFrame):
+            if schema is None:
+                schema = [str(x) for x in data.columns]
+            data = [r.tolist() for r in data.to_records(index=False)]
+
+        if isinstance(schema, StructType):
+            def prepare(obj):
+                _verify_type(obj, schema)
+                return obj
+        elif isinstance(schema, DataType):
+            datatype = schema
+
+            def prepare(obj):
+                _verify_type(obj, datatype)
+                return (obj, )
+            schema = StructType().add("value", datatype)
+        else:
+            prepare = lambda obj: obj
+
+        if isinstance(data, RDD):
+            rdd, schema = self._createFromRDD(data.map(prepare), schema, samplingRatio)
+        else:
+            rdd, schema = self._createFromLocal(map(prepare, data), schema)
+        jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
+        jdf = self._jsparkSession.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
+        df = DataFrame(jdf, self._wrapped)
+        df._schema = schema
+        return df
+
+    @since(2.0)
+    def registerDataFrameAsTable(self, df, tableName):
+        """Registers the given :class:`DataFrame` as a temporary table in the catalog.
+
+        Temporary tables exist only during the lifetime of this instance of :class:`SparkSession`.
+
+        >>> spark.registerDataFrameAsTable(df, "table1")
+        """
+        if (df.__class__ is DataFrame):
+            self._jsparkSession.registerDataFrameAsTable(df._jdf, tableName)
+        else:
+            raise ValueError("Can only register DataFrame as table")
+
+    @since(2.0)
+    def createExternalTable(self, tableName, path=None, source=None, schema=None, **options):
+        """Creates an external table based on the dataset in a data source.
+
+        It returns the DataFrame associated with the external table.
+
+        The data source is specified by the ``source`` and a set of ``options``.
+        If ``source`` is not specified, the default data source configured by
+        ``spark.sql.sources.default`` will be used.
+
+        Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and
+        created external table.
+
+        :return: :class:`DataFrame`
+        """
+        if path is not None:
+            options["path"] = path
+        if source is None:
+            source = self.getConf("spark.sql.sources.default",
+                                  "org.apache.spark.sql.parquet")
+        if schema is None:
+            df = self._jsparkSession.catalog().createExternalTable(tableName, source, options)
+        else:
+            if not isinstance(schema, StructType):
+                raise TypeError("schema should be StructType")
+            scala_datatype = self._jsparkSession.parseDataType(schema.json())
+            df = self._jsparkSession.catalog().createExternalTable(
+                tableName, source, scala_datatype, options)
+        return DataFrame(df, self._wrapped)
+
+    @ignore_unicode_prefix
+    @since(2.0)
+    def sql(self, sqlQuery):
+        """Returns a :class:`DataFrame` representing the result of the given query.
+
+        :return: :class:`DataFrame`
+
+        >>> spark.registerDataFrameAsTable(df, "table1")
+        >>> df2 = spark.sql("SELECT field1 AS f1, field2 as f2 from table1")
+        >>> df2.collect()
+        [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
+        """
+        return DataFrame(self._jsparkSession.sql(sqlQuery), self._wrapped)
+
+    @since(2.0)
+    def table(self, tableName):
+        """Returns the specified table as a :class:`DataFrame`.
+
+        :return: :class:`DataFrame`
+
+        >>> spark.registerDataFrameAsTable(df, "table1")
+        >>> df2 = spark.table("table1")
+        >>> sorted(df.collect()) == sorted(df2.collect())
+        True
+        """
+        return DataFrame(self._jsparkSession.table(tableName), self._wrapped)
+
+    @property
+    @since(2.0)
+    def read(self):
+        """
+        Returns a :class:`DataFrameReader` that can be used to read data
+        in as a :class:`DataFrame`.
+
+        :return: :class:`DataFrameReader`
+        """
+        return DataFrameReader(self._wrapped)
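The docstrings above describe how `createDataFrame` handles an explicit `StructType` schema and how `registerFunction` exposes a Python lambda to SQL. A short, hedged sketch of those two paths follows; it assumes an already-constructed `spark` session (as in the shell session in the commit message), and the table and column names are illustrative only.

```
# A hedged sketch, assuming a SparkSession `spark` already exists.
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

# Explicit schema: names and types come from the StructType, so no inference
# pass over the data is needed; each row is verified against the schema.
schema = StructType([StructField("word", StringType(), True),
                     StructField("count", IntegerType(), True)])
words = spark.createDataFrame([("spark", 3), ("session", 2)], schema)

# Register a Python UDF; the return type defaults to string unless a
# DataType such as IntegerType() is supplied.
spark.registerFunction("wordLen", lambda s: len(s), IntegerType())
spark.registerDataFrameAsTable(words, "words")
spark.sql("SELECT word, wordLen(word) AS length FROM words").show()
```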