aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndrew Or <andrew@databricks.com>2016-04-28 10:55:48 -0700
committerReynold Xin <rxin@databricks.com>2016-04-28 10:55:48 -0700
commit89addd40abdacd65cc03ac8aa5f9cf3dd4a4c19b (patch)
tree5ecd3d9a736333c7951de6159eefef86129e3744
parent5743352a28fffbfbaca2201208ce7a1d7893f813 (diff)
downloadspark-89addd40abdacd65cc03ac8aa5f9cf3dd4a4c19b.tar.gz
spark-89addd40abdacd65cc03ac8aa5f9cf3dd4a4c19b.tar.bz2
spark-89addd40abdacd65cc03ac8aa5f9cf3dd4a4c19b.zip
[SPARK-14945][PYTHON] SparkSession Python API
## What changes were proposed in this pull request? ``` Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /__ / .__/\_,_/_/ /_/\_\ version 2.0.0-SNAPSHOT /_/ Using Python version 2.7.5 (default, Mar 9 2014 22:15:05) SparkSession available as 'spark'. >>> spark <pyspark.sql.session.SparkSession object at 0x101f3bfd0> >>> spark.sql("SHOW TABLES").show() ... +---------+-----------+ |tableName|isTemporary| +---------+-----------+ | src| false| +---------+-----------+ >>> spark.range(1, 10, 2).show() +---+ | id| +---+ | 1| | 3| | 5| | 7| | 9| +---+ ``` **Note**: This API is NOT complete in its current state. In particular, for now I left out the `conf` and `catalog` APIs, which were added later in Scala. These will be added later before 2.0. ## How was this patch tested? Python tests. Author: Andrew Or <andrew@databricks.com> Closes #12746 from andrewor14/python-spark-session.
-rw-r--r--python/pyspark/shell.py11
-rw-r--r--python/pyspark/sql/__init__.py5
-rw-r--r--python/pyspark/sql/context.py278
-rw-r--r--python/pyspark/sql/readwriter.py2
-rw-r--r--python/pyspark/sql/session.py525
-rw-r--r--python/pyspark/sql/tests.py4
6 files changed, 585 insertions, 240 deletions
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
index 7c37f75193..c6b0eda996 100644
--- a/python/pyspark/shell.py
+++ b/python/pyspark/shell.py
@@ -29,7 +29,7 @@ import py4j
import pyspark
from pyspark.context import SparkContext
-from pyspark.sql import SQLContext, HiveContext
+from pyspark.sql import SparkSession, SQLContext
from pyspark.storagelevel import StorageLevel
if os.environ.get("SPARK_EXECUTOR_URI"):
@@ -41,13 +41,14 @@ atexit.register(lambda: sc.stop())
try:
# Try to access HiveConf, it will raise exception if Hive is not added
sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
- sqlContext = HiveContext(sc)
+ spark = SparkSession.withHiveSupport(sc)
except py4j.protocol.Py4JError:
- sqlContext = SQLContext(sc)
+ spark = SparkSession(sc)
except TypeError:
- sqlContext = SQLContext(sc)
+ spark = SparkSession(sc)
# for compatibility
+sqlContext = spark._wrapped
sqlCtx = sqlContext
print("""Welcome to
@@ -61,7 +62,7 @@ print("Using Python version %s (%s, %s)" % (
platform.python_version(),
platform.python_build()[0],
platform.python_build()[1]))
-print("SparkContext available as sc, %s available as sqlContext." % sqlContext.__class__.__name__)
+print("SparkSession available as 'spark'.")
# The ./bin/pyspark script stores the old PYTHONSTARTUP value in OLD_PYTHONSTARTUP,
# which allows us to execute the user's PYTHONSTARTUP file:
diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
index 0b06c8339f..cff73ff192 100644
--- a/python/pyspark/sql/__init__.py
+++ b/python/pyspark/sql/__init__.py
@@ -46,6 +46,7 @@ from __future__ import absolute_import
from pyspark.sql.types import Row
from pyspark.sql.context import SQLContext, HiveContext
+from pyspark.sql.session import SparkSession
from pyspark.sql.column import Column
from pyspark.sql.dataframe import DataFrame, DataFrameNaFunctions, DataFrameStatFunctions
from pyspark.sql.group import GroupedData
@@ -54,7 +55,7 @@ from pyspark.sql.window import Window, WindowSpec
__all__ = [
- 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row',
- 'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec',
+ 'SparkSession', 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column',
+ 'Row', 'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec',
'DataFrameReader', 'DataFrameWriter'
]
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 600a6e0bc2..48ffb59668 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -17,70 +17,37 @@
from __future__ import print_function
import sys
-import warnings
-import json
-from functools import reduce
if sys.version >= '3':
basestring = unicode = str
-else:
- from itertools import imap as map
-
-from py4j.protocol import Py4JError
from pyspark import since
-from pyspark.rdd import RDD, ignore_unicode_prefix
-from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
-from pyspark.sql.types import Row, DataType, StringType, StructType, _verify_type, \
- _infer_schema, _has_nulltype, _merge_type, _create_converter, _parse_datatype_string
+from pyspark.rdd import ignore_unicode_prefix
+from pyspark.sql.session import _monkey_patch_RDD, SparkSession
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.readwriter import DataFrameReader
+from pyspark.sql.types import Row, StringType
from pyspark.sql.utils import install_exception_handler
-from pyspark.sql.functions import UserDefinedFunction
-
-try:
- import pandas
- has_pandas = True
-except Exception:
- has_pandas = False
__all__ = ["SQLContext", "HiveContext", "UDFRegistration"]
-def _monkey_patch_RDD(sqlContext):
- def toDF(self, schema=None, sampleRatio=None):
- """
- Converts current :class:`RDD` into a :class:`DataFrame`
-
- This is a shorthand for ``sqlContext.createDataFrame(rdd, schema, sampleRatio)``
-
- :param schema: a StructType or list of names of columns
- :param samplingRatio: the sample ratio of rows used for inferring
- :return: a DataFrame
-
- >>> rdd.toDF().collect()
- [Row(name=u'Alice', age=1)]
- """
- return sqlContext.createDataFrame(self, schema, sampleRatio)
-
- RDD.toDF = toDF
-
-
class SQLContext(object):
- """Main entry point for Spark SQL functionality.
+ """Wrapper around :class:`SparkSession`, the main entry point to Spark SQL functionality.
A SQLContext can be used create :class:`DataFrame`, register :class:`DataFrame` as
tables, execute SQL over tables, cache tables, and read parquet files.
:param sparkContext: The :class:`SparkContext` backing this SQLContext.
- :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new
+ :param sparkSession: The :class:`SparkSession` around which this SQLContext wraps.
+ :param jsqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new
SQLContext in the JVM, instead we make all calls to this object.
"""
_instantiatedContext = None
@ignore_unicode_prefix
- def __init__(self, sparkContext, sqlContext=None):
+ def __init__(self, sparkContext, sparkSession=None, jsqlContext=None):
"""Creates a new SQLContext.
>>> from datetime import datetime
@@ -100,8 +67,13 @@ class SQLContext(object):
self._sc = sparkContext
self._jsc = self._sc._jsc
self._jvm = self._sc._jvm
- self._scala_SQLContext = sqlContext
- _monkey_patch_RDD(self)
+ if sparkSession is None:
+ sparkSession = SparkSession(sparkContext)
+ if jsqlContext is None:
+ jsqlContext = sparkSession._jwrapped
+ self.sparkSession = sparkSession
+ self._jsqlContext = jsqlContext
+ _monkey_patch_RDD(self.sparkSession)
install_exception_handler()
if SQLContext._instantiatedContext is None:
SQLContext._instantiatedContext = self
@@ -113,9 +85,7 @@ class SQLContext(object):
Subclasses can override this property to provide their own
JVM Contexts.
"""
- if self._scala_SQLContext is None:
- self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
- return self._scala_SQLContext
+ return self._jsqlContext
@classmethod
@since(1.6)
@@ -127,7 +97,8 @@ class SQLContext(object):
"""
if cls._instantiatedContext is None:
jsqlContext = sc._jvm.SQLContext.getOrCreate(sc._jsc.sc())
- cls(sc, jsqlContext)
+ sparkSession = SparkSession(sc, jsqlContext.sparkSession())
+ cls(sc, sparkSession, jsqlContext)
return cls._instantiatedContext
@since(1.6)
@@ -137,14 +108,13 @@ class SQLContext(object):
registered temporary tables and UDFs, but shared SparkContext and
table cache.
"""
- jsqlContext = self._ssql_ctx.newSession()
- return self.__class__(self._sc, jsqlContext)
+ return self.__class__(self._sc, self.sparkSession.newSession())
@since(1.3)
def setConf(self, key, value):
"""Sets the given Spark SQL configuration property.
"""
- self._ssql_ctx.setConf(key, value)
+ self.sparkSession.setConf(key, value)
@ignore_unicode_prefix
@since(1.3)
@@ -163,10 +133,7 @@ class SQLContext(object):
>>> sqlContext.getConf("spark.sql.shuffle.partitions", "10")
u'50'
"""
- if defaultValue is not None:
- return self._ssql_ctx.getConf(key, defaultValue)
- else:
- return self._ssql_ctx.getConf(key)
+ return self.sparkSession.getConf(key, defaultValue)
@property
@since("1.3.1")
@@ -175,7 +142,7 @@ class SQLContext(object):
:return: :class:`UDFRegistration`
"""
- return UDFRegistration(self)
+ return UDFRegistration(self.sparkSession)
@since(1.4)
def range(self, start, end=None, step=1, numPartitions=None):
@@ -198,15 +165,7 @@ class SQLContext(object):
>>> sqlContext.range(3).collect()
[Row(id=0), Row(id=1), Row(id=2)]
"""
- if numPartitions is None:
- numPartitions = self._sc.defaultParallelism
-
- if end is None:
- jdf = self._ssql_ctx.range(0, int(start), int(step), int(numPartitions))
- else:
- jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions))
-
- return DataFrame(jdf, self)
+ return self.sparkSession.range(start, end, step, numPartitions)
@ignore_unicode_prefix
@since(1.2)
@@ -236,27 +195,9 @@ class SQLContext(object):
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(stringLengthInt(test)=4)]
"""
- udf = UserDefinedFunction(f, returnType, name)
- self._ssql_ctx.udf().registerPython(name, udf._judf)
-
- def _inferSchemaFromList(self, data):
- """
- Infer schema from list of Row or tuple.
-
- :param data: list of Row or tuple
- :return: StructType
- """
- if not data:
- raise ValueError("can not infer schema from empty dataset")
- first = data[0]
- if type(first) is dict:
- warnings.warn("inferring schema from dict is deprecated,"
- "please use pyspark.sql.Row instead")
- schema = reduce(_merge_type, map(_infer_schema, data))
- if _has_nulltype(schema):
- raise ValueError("Some of types cannot be determined after inferring")
- return schema
+ self.sparkSession.registerFunction(name, f, returnType)
+ # TODO(andrew): delete this once we refactor things to take in SparkSession
def _inferSchema(self, rdd, samplingRatio=None):
"""
Infer schema from an RDD of Row or tuple.
@@ -265,78 +206,7 @@ class SQLContext(object):
:param samplingRatio: sampling ratio, or no sampling (default)
:return: StructType
"""
- first = rdd.first()
- if not first:
- raise ValueError("The first row in RDD is empty, "
- "can not infer schema")
- if type(first) is dict:
- warnings.warn("Using RDD of dict to inferSchema is deprecated. "
- "Use pyspark.sql.Row instead")
-
- if samplingRatio is None:
- schema = _infer_schema(first)
- if _has_nulltype(schema):
- for row in rdd.take(100)[1:]:
- schema = _merge_type(schema, _infer_schema(row))
- if not _has_nulltype(schema):
- break
- else:
- raise ValueError("Some of types cannot be determined by the "
- "first 100 rows, please try again with sampling")
- else:
- if samplingRatio < 0.99:
- rdd = rdd.sample(False, float(samplingRatio))
- schema = rdd.map(_infer_schema).reduce(_merge_type)
- return schema
-
- def _createFromRDD(self, rdd, schema, samplingRatio):
- """
- Create an RDD for DataFrame from an existing RDD, returns the RDD and schema.
- """
- if schema is None or isinstance(schema, (list, tuple)):
- struct = self._inferSchema(rdd, samplingRatio)
- converter = _create_converter(struct)
- rdd = rdd.map(converter)
- if isinstance(schema, (list, tuple)):
- for i, name in enumerate(schema):
- struct.fields[i].name = name
- struct.names[i] = name
- schema = struct
-
- elif not isinstance(schema, StructType):
- raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
-
- # convert python objects to sql data
- rdd = rdd.map(schema.toInternal)
- return rdd, schema
-
- def _createFromLocal(self, data, schema):
- """
- Create an RDD for DataFrame from an list or pandas.DataFrame, returns
- the RDD and schema.
- """
- # make sure data could consumed multiple times
- if not isinstance(data, list):
- data = list(data)
-
- if schema is None or isinstance(schema, (list, tuple)):
- struct = self._inferSchemaFromList(data)
- if isinstance(schema, (list, tuple)):
- for i, name in enumerate(schema):
- struct.fields[i].name = name
- struct.names[i] = name
- schema = struct
-
- elif isinstance(schema, StructType):
- for row in data:
- _verify_type(row, schema)
-
- else:
- raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
-
- # convert python objects to sql data
- data = [schema.toInternal(row) for row in data]
- return self._sc.parallelize(data), schema
+ return self.sparkSession._inferSchema(rdd, samplingRatio)
@since(1.3)
@ignore_unicode_prefix
@@ -421,40 +291,7 @@ class SQLContext(object):
...
Py4JJavaError: ...
"""
- if isinstance(data, DataFrame):
- raise TypeError("data is already a DataFrame")
-
- if isinstance(schema, basestring):
- schema = _parse_datatype_string(schema)
-
- if has_pandas and isinstance(data, pandas.DataFrame):
- if schema is None:
- schema = [str(x) for x in data.columns]
- data = [r.tolist() for r in data.to_records(index=False)]
-
- if isinstance(schema, StructType):
- def prepare(obj):
- _verify_type(obj, schema)
- return obj
- elif isinstance(schema, DataType):
- datatype = schema
-
- def prepare(obj):
- _verify_type(obj, datatype)
- return (obj, )
- schema = StructType().add("value", datatype)
- else:
- prepare = lambda obj: obj
-
- if isinstance(data, RDD):
- rdd, schema = self._createFromRDD(data.map(prepare), schema, samplingRatio)
- else:
- rdd, schema = self._createFromLocal(map(prepare, data), schema)
- jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
- jdf = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
- df = DataFrame(jdf, self)
- df._schema = schema
- return df
+ return self.sparkSession.createDataFrame(data, schema, samplingRatio)
@since(1.3)
def registerDataFrameAsTable(self, df, tableName):
@@ -464,10 +301,7 @@ class SQLContext(object):
>>> sqlContext.registerDataFrameAsTable(df, "table1")
"""
- if (df.__class__ is DataFrame):
- self._ssql_ctx.registerDataFrameAsTable(df._jdf, tableName)
- else:
- raise ValueError("Can only register DataFrame as table")
+ self.sparkSession.registerDataFrameAsTable(df, tableName)
@since(1.6)
def dropTempTable(self, tableName):
@@ -493,20 +327,7 @@ class SQLContext(object):
:return: :class:`DataFrame`
"""
- if path is not None:
- options["path"] = path
- if source is None:
- source = self.getConf("spark.sql.sources.default",
- "org.apache.spark.sql.parquet")
- if schema is None:
- df = self._ssql_ctx.createExternalTable(tableName, source, options)
- else:
- if not isinstance(schema, StructType):
- raise TypeError("schema should be StructType")
- scala_datatype = self._ssql_ctx.parseDataType(schema.json())
- df = self._ssql_ctx.createExternalTable(tableName, source, scala_datatype,
- options)
- return DataFrame(df, self)
+ return self.sparkSession.createExternalTable(tableName, path, source, schema, **options)
@ignore_unicode_prefix
@since(1.0)
@@ -520,7 +341,7 @@ class SQLContext(object):
>>> df2.collect()
[Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
"""
- return DataFrame(self._ssql_ctx.sql(sqlQuery), self)
+ return self.sparkSession.sql(sqlQuery)
@since(1.0)
def table(self, tableName):
@@ -533,7 +354,7 @@ class SQLContext(object):
>>> sorted(df.collect()) == sorted(df2.collect())
True
"""
- return DataFrame(self._ssql_ctx.table(tableName), self)
+ return self.sparkSession.table(tableName)
@ignore_unicode_prefix
@since(1.3)
@@ -603,7 +424,7 @@ class SQLContext(object):
return DataFrameReader(self)
-# TODO(andrew): remove this too
+# TODO(andrew): deprecate this
class HiveContext(SQLContext):
"""A variant of Spark SQL that integrates with data stored in Hive.
@@ -611,29 +432,28 @@ class HiveContext(SQLContext):
It supports running both SQL and HiveQL commands.
:param sparkContext: The SparkContext to wrap.
- :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instantiate a new
+ :param jhiveContext: An optional JVM Scala HiveContext. If set, we do not instantiate a new
:class:`HiveContext` in the JVM, instead we make all calls to this object.
"""
- def __init__(self, sparkContext, hiveContext=None):
- SQLContext.__init__(self, sparkContext)
- if hiveContext:
- self._scala_HiveContext = hiveContext
+ def __init__(self, sparkContext, jhiveContext=None):
+ if jhiveContext is None:
+ sparkSession = SparkSession.withHiveSupport(sparkContext)
+ else:
+ sparkSession = SparkSession(sparkContext, jhiveContext.sparkSession())
+ SQLContext.__init__(self, sparkContext, sparkSession, jhiveContext)
+
+ @classmethod
+ def _createForTesting(cls, sparkContext):
+ """(Internal use only) Create a new HiveContext for testing.
- @property
- def _ssql_ctx(self):
- try:
- if not hasattr(self, '_scala_HiveContext'):
- self._scala_HiveContext = self._get_hive_ctx()
- return self._scala_HiveContext
- except Py4JError as e:
- print("You must build Spark with Hive. "
- "Export 'SPARK_HIVE=true' and run "
- "build/sbt assembly", file=sys.stderr)
- raise
-
- def _get_hive_ctx(self):
- return self._jvm.SparkSession.withHiveSupport(self._jsc.sc()).wrapped()
+ All test code that touches HiveContext *must* go through this method. Otherwise,
+ you may end up launching multiple derby instances and encounter with incredibly
+ confusing error messages.
+ """
+ jsc = sparkContext._jsc.sc()
+ jtestHive = sparkContext._jvm.org.apache.spark.sql.hive.test.TestHiveContext(jsc)
+ return cls(sparkContext, jtestHive)
def refreshTable(self, tableName):
"""Invalidate and refresh all the cached the metadata of the given
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index e39cf1ae03..784609e4c5 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -743,7 +743,7 @@ def _test():
globs['os'] = os
globs['sc'] = sc
globs['sqlContext'] = SQLContext(sc)
- globs['hiveContext'] = HiveContext(sc)
+ globs['hiveContext'] = HiveContext._createForTesting(sc)
globs['df'] = globs['sqlContext'].read.parquet('python/test_support/sql/parquet_partitioned')
globs['sdf'] =\
globs['sqlContext'].read.format('text').stream('python/test_support/sql/streaming')
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
new file mode 100644
index 0000000000..d3355f9da7
--- /dev/null
+++ b/python/pyspark/sql/session.py
@@ -0,0 +1,525 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+import sys
+import warnings
+from functools import reduce
+
+if sys.version >= '3':
+ basestring = unicode = str
+else:
+ from itertools import imap as map
+
+from pyspark import since
+from pyspark.rdd import RDD, ignore_unicode_prefix
+from pyspark.sql.dataframe import DataFrame
+from pyspark.sql.functions import UserDefinedFunction
+from pyspark.sql.readwriter import DataFrameReader
+from pyspark.sql.types import Row, DataType, StringType, StructType, _verify_type, \
+ _infer_schema, _has_nulltype, _merge_type, _create_converter, _parse_datatype_string
+from pyspark.sql.utils import install_exception_handler
+
+__all__ = ["SparkSession"]
+
+
+def _monkey_patch_RDD(sparkSession):
+ def toDF(self, schema=None, sampleRatio=None):
+ """
+ Converts current :class:`RDD` into a :class:`DataFrame`
+
+ This is a shorthand for ``spark.createDataFrame(rdd, schema, sampleRatio)``
+
+ :param schema: a StructType or list of names of columns
+ :param samplingRatio: the sample ratio of rows used for inferring
+ :return: a DataFrame
+
+ >>> rdd.toDF().collect()
+ [Row(name=u'Alice', age=1)]
+ """
+ return sparkSession.createDataFrame(self, schema, sampleRatio)
+
+ RDD.toDF = toDF
+
+
+# TODO(andrew): implement conf and catalog namespaces
+class SparkSession(object):
+ """Main entry point for Spark SQL functionality.
+
+ A SparkSession can be used create :class:`DataFrame`, register :class:`DataFrame` as
+ tables, execute SQL over tables, cache tables, and read parquet files.
+
+ :param sparkContext: The :class:`SparkContext` backing this SparkSession.
+ :param jsparkSession: An optional JVM Scala SparkSession. If set, we do not instantiate a new
+ SparkSession in the JVM, instead we make all calls to this object.
+ """
+
+ _instantiatedContext = None
+
+ @ignore_unicode_prefix
+ def __init__(self, sparkContext, jsparkSession=None):
+ """Creates a new SparkSession.
+
+ >>> from datetime import datetime
+ >>> spark = SparkSession(sc)
+ >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1,
+ ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1),
+ ... time=datetime(2014, 8, 1, 14, 1, 5))])
+ >>> df = allTypes.toDF()
+ >>> df.registerTempTable("allTypes")
+ >>> spark.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
+ ... 'from allTypes where b and i > 0').collect()
+ [Row((i + CAST(1 AS BIGINT))=2, (d + CAST(1 AS DOUBLE))=2.0, (NOT b)=False, list[1]=2, \
+ dict[s]=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
+ >>> df.rdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect()
+ [(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
+ """
+ from pyspark.sql.context import SQLContext
+ self._sc = sparkContext
+ self._jsc = self._sc._jsc
+ self._jvm = self._sc._jvm
+ if jsparkSession is None:
+ jsparkSession = self._jvm.SparkSession(self._jsc.sc())
+ self._jsparkSession = jsparkSession
+ self._jwrapped = self._jsparkSession.wrapped()
+ self._wrapped = SQLContext(self._sc, self, self._jwrapped)
+ _monkey_patch_RDD(self)
+ install_exception_handler()
+ if SparkSession._instantiatedContext is None:
+ SparkSession._instantiatedContext = self
+
+ @classmethod
+ @since(2.0)
+ def withHiveSupport(cls, sparkContext):
+ """Returns a new SparkSession with a catalog backed by Hive
+
+ :param sparkContext: The underlying :class:`SparkContext`.
+ """
+ jsparkSession = sparkContext._jvm.SparkSession.withHiveSupport(sparkContext._jsc.sc())
+ return cls(sparkContext, jsparkSession)
+
+ @since(2.0)
+ def newSession(self):
+ """
+ Returns a new SparkSession as new session, that has separate SQLConf,
+ registered temporary tables and UDFs, but shared SparkContext and
+ table cache.
+ """
+ return self.__class__(self._sc, self._jsparkSession.newSession())
+
+ @since(2.0)
+ def setConf(self, key, value):
+ """
+ Sets the given Spark SQL configuration property.
+ """
+ self._jsparkSession.setConf(key, value)
+
+ @ignore_unicode_prefix
+ @since(2.0)
+ def getConf(self, key, defaultValue=None):
+ """Returns the value of Spark SQL configuration property for the given key.
+
+ If the key is not set and defaultValue is not None, return
+ defaultValue. If the key is not set and defaultValue is None, return
+ the system default value.
+
+ >>> spark.getConf("spark.sql.shuffle.partitions")
+ u'200'
+ >>> spark.getConf("spark.sql.shuffle.partitions", "10")
+ u'10'
+ >>> spark.setConf("spark.sql.shuffle.partitions", "50")
+ >>> spark.getConf("spark.sql.shuffle.partitions", "10")
+ u'50'
+ """
+ if defaultValue is not None:
+ return self._jsparkSession.getConf(key, defaultValue)
+ else:
+ return self._jsparkSession.getConf(key)
+
+ @property
+ @since(2.0)
+ def udf(self):
+ """Returns a :class:`UDFRegistration` for UDF registration.
+
+ :return: :class:`UDFRegistration`
+ """
+ return UDFRegistration(self)
+
+ @since(2.0)
+ def range(self, start, end=None, step=1, numPartitions=None):
+ """
+ Create a :class:`DataFrame` with single LongType column named `id`,
+ containing elements in a range from `start` to `end` (exclusive) with
+ step value `step`.
+
+ :param start: the start value
+ :param end: the end value (exclusive)
+ :param step: the incremental step (default: 1)
+ :param numPartitions: the number of partitions of the DataFrame
+ :return: :class:`DataFrame`
+
+ >>> spark.range(1, 7, 2).collect()
+ [Row(id=1), Row(id=3), Row(id=5)]
+
+ If only one argument is specified, it will be used as the end value.
+
+ >>> spark.range(3).collect()
+ [Row(id=0), Row(id=1), Row(id=2)]
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+
+ if end is None:
+ jdf = self._jsparkSession.range(0, int(start), int(step), int(numPartitions))
+ else:
+ jdf = self._jsparkSession.range(int(start), int(end), int(step), int(numPartitions))
+
+ return DataFrame(jdf, self._wrapped)
+
+ @ignore_unicode_prefix
+ @since(2.0)
+ def registerFunction(self, name, f, returnType=StringType()):
+ """Registers a python function (including lambda function) as a UDF
+ so it can be used in SQL statements.
+
+ In addition to a name and the function itself, the return type can be optionally specified.
+ When the return type is not given it default to a string and conversion will automatically
+ be done. For any other return type, the produced object must match the specified type.
+
+ :param name: name of the UDF
+ :param f: python function
+ :param returnType: a :class:`DataType` object
+
+ >>> spark.registerFunction("stringLengthString", lambda x: len(x))
+ >>> spark.sql("SELECT stringLengthString('test')").collect()
+ [Row(stringLengthString(test)=u'4')]
+
+ >>> from pyspark.sql.types import IntegerType
+ >>> spark.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
+ >>> spark.sql("SELECT stringLengthInt('test')").collect()
+ [Row(stringLengthInt(test)=4)]
+
+ >>> from pyspark.sql.types import IntegerType
+ >>> spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
+ >>> spark.sql("SELECT stringLengthInt('test')").collect()
+ [Row(stringLengthInt(test)=4)]
+ """
+ udf = UserDefinedFunction(f, returnType, name)
+ self._jsparkSession.udf().registerPython(name, udf._judf)
+
+ def _inferSchemaFromList(self, data):
+ """
+ Infer schema from list of Row or tuple.
+
+ :param data: list of Row or tuple
+ :return: StructType
+ """
+ if not data:
+ raise ValueError("can not infer schema from empty dataset")
+ first = data[0]
+ if type(first) is dict:
+ warnings.warn("inferring schema from dict is deprecated,"
+ "please use pyspark.sql.Row instead")
+ schema = reduce(_merge_type, map(_infer_schema, data))
+ if _has_nulltype(schema):
+ raise ValueError("Some of types cannot be determined after inferring")
+ return schema
+
+ def _inferSchema(self, rdd, samplingRatio=None):
+ """
+ Infer schema from an RDD of Row or tuple.
+
+ :param rdd: an RDD of Row or tuple
+ :param samplingRatio: sampling ratio, or no sampling (default)
+ :return: StructType
+ """
+ first = rdd.first()
+ if not first:
+ raise ValueError("The first row in RDD is empty, "
+ "can not infer schema")
+ if type(first) is dict:
+ warnings.warn("Using RDD of dict to inferSchema is deprecated. "
+ "Use pyspark.sql.Row instead")
+
+ if samplingRatio is None:
+ schema = _infer_schema(first)
+ if _has_nulltype(schema):
+ for row in rdd.take(100)[1:]:
+ schema = _merge_type(schema, _infer_schema(row))
+ if not _has_nulltype(schema):
+ break
+ else:
+ raise ValueError("Some of types cannot be determined by the "
+ "first 100 rows, please try again with sampling")
+ else:
+ if samplingRatio < 0.99:
+ rdd = rdd.sample(False, float(samplingRatio))
+ schema = rdd.map(_infer_schema).reduce(_merge_type)
+ return schema
+
+ def _createFromRDD(self, rdd, schema, samplingRatio):
+ """
+ Create an RDD for DataFrame from an existing RDD, returns the RDD and schema.
+ """
+ if schema is None or isinstance(schema, (list, tuple)):
+ struct = self._inferSchema(rdd, samplingRatio)
+ converter = _create_converter(struct)
+ rdd = rdd.map(converter)
+ if isinstance(schema, (list, tuple)):
+ for i, name in enumerate(schema):
+ struct.fields[i].name = name
+ struct.names[i] = name
+ schema = struct
+
+ elif not isinstance(schema, StructType):
+ raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
+
+ # convert python objects to sql data
+ rdd = rdd.map(schema.toInternal)
+ return rdd, schema
+
+ def _createFromLocal(self, data, schema):
+ """
+ Create an RDD for DataFrame from an list or pandas.DataFrame, returns
+ the RDD and schema.
+ """
+ # make sure data could consumed multiple times
+ if not isinstance(data, list):
+ data = list(data)
+
+ if schema is None or isinstance(schema, (list, tuple)):
+ struct = self._inferSchemaFromList(data)
+ if isinstance(schema, (list, tuple)):
+ for i, name in enumerate(schema):
+ struct.fields[i].name = name
+ struct.names[i] = name
+ schema = struct
+
+ elif isinstance(schema, StructType):
+ for row in data:
+ _verify_type(row, schema)
+
+ else:
+ raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
+
+ # convert python objects to sql data
+ data = [schema.toInternal(row) for row in data]
+ return self._sc.parallelize(data), schema
+
+ @since(2.0)
+ @ignore_unicode_prefix
+ def createDataFrame(self, data, schema=None, samplingRatio=None):
+ """
+ Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.
+
+ When ``schema`` is a list of column names, the type of each column
+ will be inferred from ``data``.
+
+ When ``schema`` is ``None``, it will try to infer the schema (column names and types)
+ from ``data``, which should be an RDD of :class:`Row`,
+ or :class:`namedtuple`, or :class:`dict`.
+
+ When ``schema`` is :class:`DataType` or datatype string, it must match the real data, or
+ exception will be thrown at runtime. If the given schema is not StructType, it will be
+ wrapped into a StructType as its only field, and the field name will be "value", each record
+ will also be wrapped into a tuple, which can be converted to row later.
+
+ If schema inference is needed, ``samplingRatio`` is used to determined the ratio of
+ rows used for schema inference. The first row will be used if ``samplingRatio`` is ``None``.
+
+ :param data: an RDD of any kind of SQL data representation(e.g. row, tuple, int, boolean,
+ etc.), or :class:`list`, or :class:`pandas.DataFrame`.
+ :param schema: a :class:`DataType` or a datatype string or a list of column names, default
+ is None. The data type string format equals to `DataType.simpleString`, except that
+ top level struct type can omit the `struct<>` and atomic types use `typeName()` as
+ their format, e.g. use `byte` instead of `tinyint` for ByteType. We can also use `int`
+ as a short name for IntegerType.
+ :param samplingRatio: the sample ratio of rows used for inferring
+ :return: :class:`DataFrame`
+
+ .. versionchanged:: 2.0
+ The schema parameter can be a DataType or a datatype string after 2.0. If it's not a
+ StructType, it will be wrapped into a StructType and each record will also be wrapped
+ into a tuple.
+
+ >>> l = [('Alice', 1)]
+ >>> spark.createDataFrame(l).collect()
+ [Row(_1=u'Alice', _2=1)]
+ >>> spark.createDataFrame(l, ['name', 'age']).collect()
+ [Row(name=u'Alice', age=1)]
+
+ >>> d = [{'name': 'Alice', 'age': 1}]
+ >>> spark.createDataFrame(d).collect()
+ [Row(age=1, name=u'Alice')]
+
+ >>> rdd = sc.parallelize(l)
+ >>> spark.createDataFrame(rdd).collect()
+ [Row(_1=u'Alice', _2=1)]
+ >>> df = spark.createDataFrame(rdd, ['name', 'age'])
+ >>> df.collect()
+ [Row(name=u'Alice', age=1)]
+
+ >>> from pyspark.sql import Row
+ >>> Person = Row('name', 'age')
+ >>> person = rdd.map(lambda r: Person(*r))
+ >>> df2 = spark.createDataFrame(person)
+ >>> df2.collect()
+ [Row(name=u'Alice', age=1)]
+
+ >>> from pyspark.sql.types import *
+ >>> schema = StructType([
+ ... StructField("name", StringType(), True),
+ ... StructField("age", IntegerType(), True)])
+ >>> df3 = spark.createDataFrame(rdd, schema)
+ >>> df3.collect()
+ [Row(name=u'Alice', age=1)]
+
+ >>> spark.createDataFrame(df.toPandas()).collect() # doctest: +SKIP
+ [Row(name=u'Alice', age=1)]
+ >>> spark.createDataFrame(pandas.DataFrame([[1, 2]])).collect() # doctest: +SKIP
+ [Row(0=1, 1=2)]
+
+ >>> spark.createDataFrame(rdd, "a: string, b: int").collect()
+ [Row(a=u'Alice', b=1)]
+ >>> rdd = rdd.map(lambda row: row[1])
+ >>> spark.createDataFrame(rdd, "int").collect()
+ [Row(value=1)]
+ >>> spark.createDataFrame(rdd, "boolean").collect() # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ Py4JJavaError: ...
+ """
+ if isinstance(data, DataFrame):
+ raise TypeError("data is already a DataFrame")
+
+ if isinstance(schema, basestring):
+ schema = _parse_datatype_string(schema)
+
+ try:
+ import pandas
+ has_pandas = True
+ except Exception:
+ has_pandas = False
+ if has_pandas and isinstance(data, pandas.DataFrame):
+ if schema is None:
+ schema = [str(x) for x in data.columns]
+ data = [r.tolist() for r in data.to_records(index=False)]
+
+ if isinstance(schema, StructType):
+ def prepare(obj):
+ _verify_type(obj, schema)
+ return obj
+ elif isinstance(schema, DataType):
+ datatype = schema
+
+ def prepare(obj):
+ _verify_type(obj, datatype)
+ return (obj, )
+ schema = StructType().add("value", datatype)
+ else:
+ prepare = lambda obj: obj
+
+ if isinstance(data, RDD):
+ rdd, schema = self._createFromRDD(data.map(prepare), schema, samplingRatio)
+ else:
+ rdd, schema = self._createFromLocal(map(prepare, data), schema)
+ jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
+ jdf = self._jsparkSession.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
+ df = DataFrame(jdf, self._wrapped)
+ df._schema = schema
+ return df
+
+ @since(2.0)
+ def registerDataFrameAsTable(self, df, tableName):
+ """Registers the given :class:`DataFrame` as a temporary table in the catalog.
+
+ Temporary tables exist only during the lifetime of this instance of :class:`SparkSession`.
+
+ >>> spark.registerDataFrameAsTable(df, "table1")
+ """
+ if (df.__class__ is DataFrame):
+ self._jsparkSession.registerDataFrameAsTable(df._jdf, tableName)
+ else:
+ raise ValueError("Can only register DataFrame as table")
+
+ @since(2.0)
+ def createExternalTable(self, tableName, path=None, source=None, schema=None, **options):
+ """Creates an external table based on the dataset in a data source.
+
+ It returns the DataFrame associated with the external table.
+
+ The data source is specified by the ``source`` and a set of ``options``.
+ If ``source`` is not specified, the default data source configured by
+ ``spark.sql.sources.default`` will be used.
+
+ Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and
+ created external table.
+
+ :return: :class:`DataFrame`
+ """
+ if path is not None:
+ options["path"] = path
+ if source is None:
+ source = self.getConf("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ if schema is None:
+ df = self._jsparkSession.catalog().createExternalTable(tableName, source, options)
+ else:
+ if not isinstance(schema, StructType):
+ raise TypeError("schema should be StructType")
+ scala_datatype = self._jsparkSession.parseDataType(schema.json())
+ df = self._jsparkSession.catalog().createExternalTable(
+ tableName, source, scala_datatype, options)
+ return DataFrame(df, self._wrapped)
+
+ @ignore_unicode_prefix
+ @since(2.0)
+ def sql(self, sqlQuery):
+ """Returns a :class:`DataFrame` representing the result of the given query.
+
+ :return: :class:`DataFrame`
+
+ >>> spark.registerDataFrameAsTable(df, "table1")
+ >>> df2 = spark.sql("SELECT field1 AS f1, field2 as f2 from table1")
+ >>> df2.collect()
+ [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
+ """
+ return DataFrame(self._jsparkSession.sql(sqlQuery), self._wrapped)
+
+ @since(2.0)
+ def table(self, tableName):
+ """Returns the specified table as a :class:`DataFrame`.
+
+ :return: :class:`DataFrame`
+
+ >>> spark.registerDataFrameAsTable(df, "table1")
+ >>> df2 = spark.table("table1")
+ >>> sorted(df.collect()) == sorted(df2.collect())
+ True
+ """
+ return DataFrame(self._jsparkSession.table(tableName), self._wrapped)
+
+ @property
+ @since(2.0)
+ def read(self):
+ """
+ Returns a :class:`DataFrameReader` that can be used to read data
+ in as a :class:`DataFrame`.
+
+ :return: :class:`DataFrameReader`
+ """
+ return DataFrameReader(self._wrapped)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 42e283073f..99a12d639a 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1369,9 +1369,7 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
cls.tearDownClass()
raise unittest.SkipTest("Hive is not available")
os.unlink(cls.tempdir.name)
- _scala_HiveContext =\
- cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc())
- cls.sqlCtx = HiveContext(cls.sc, _scala_HiveContext)
+ cls.sqlCtx = HiveContext._createForTesting(cls.sc)
cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
cls.df = cls.sc.parallelize(cls.testData).toDF()