Diffstat (limited to 'python/pyspark/sql/context.py')
 python/pyspark/sql/context.py | 278
 1 file changed, 49 insertions(+), 229 deletions(-)
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 600a6e0bc2..48ffb59668 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -17,70 +17,37 @@
from __future__ import print_function
import sys
-import warnings
-import json
-from functools import reduce
if sys.version >= '3':
basestring = unicode = str
-else:
- from itertools import imap as map
-
-from py4j.protocol import Py4JError
from pyspark import since
-from pyspark.rdd import RDD, ignore_unicode_prefix
-from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
-from pyspark.sql.types import Row, DataType, StringType, StructType, _verify_type, \
- _infer_schema, _has_nulltype, _merge_type, _create_converter, _parse_datatype_string
+from pyspark.rdd import ignore_unicode_prefix
+from pyspark.sql.session import _monkey_patch_RDD, SparkSession
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.readwriter import DataFrameReader
+from pyspark.sql.types import Row, StringType
from pyspark.sql.utils import install_exception_handler
-from pyspark.sql.functions import UserDefinedFunction
-
-try:
- import pandas
- has_pandas = True
-except Exception:
- has_pandas = False
__all__ = ["SQLContext", "HiveContext", "UDFRegistration"]
-def _monkey_patch_RDD(sqlContext):
- def toDF(self, schema=None, sampleRatio=None):
- """
- Converts current :class:`RDD` into a :class:`DataFrame`
-
- This is a shorthand for ``sqlContext.createDataFrame(rdd, schema, sampleRatio)``
-
- :param schema: a StructType or list of names of columns
- :param samplingRatio: the sample ratio of rows used for inferring
- :return: a DataFrame
-
- >>> rdd.toDF().collect()
- [Row(name=u'Alice', age=1)]
- """
- return sqlContext.createDataFrame(self, schema, sampleRatio)
-
- RDD.toDF = toDF
-
-
class SQLContext(object):
- """Main entry point for Spark SQL functionality.
+ """Wrapper around :class:`SparkSession`, the main entry point to Spark SQL functionality.
A SQLContext can be used to create :class:`DataFrame`, register :class:`DataFrame` as
tables, execute SQL over tables, cache tables, and read parquet files.
:param sparkContext: The :class:`SparkContext` backing this SQLContext.
- :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new
+ :param sparkSession: The :class:`SparkSession` that this SQLContext wraps.
+ :param jsqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new
SQLContext in the JVM; instead we make all calls to this object.
"""
_instantiatedContext = None
@ignore_unicode_prefix
- def __init__(self, sparkContext, sqlContext=None):
+ def __init__(self, sparkContext, sparkSession=None, jsqlContext=None):
"""Creates a new SQLContext.
>>> from datetime import datetime
@@ -100,8 +67,13 @@ class SQLContext(object):
self._sc = sparkContext
self._jsc = self._sc._jsc
self._jvm = self._sc._jvm
- self._scala_SQLContext = sqlContext
- _monkey_patch_RDD(self)
+ if sparkSession is None:
+ sparkSession = SparkSession(sparkContext)
+ if jsqlContext is None:
+ jsqlContext = sparkSession._jwrapped
+ self.sparkSession = sparkSession
+ self._jsqlContext = jsqlContext
+ _monkey_patch_RDD(self.sparkSession)
install_exception_handler()
if SQLContext._instantiatedContext is None:
SQLContext._instantiatedContext = self
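With this change the constructor is just a thin wrapper: if no session is supplied it builds a SparkSession from the SparkContext and borrows its JVM wrapper. A minimal sketch of what callers see, assuming a live SparkContext `sc` as in the module doctests:

    from pyspark.sql import SQLContext

    sqlContext = SQLContext(sc)            # builds and wraps a SparkSession internally
    spark = sqlContext.sparkSession        # the wrapped session, exposed as an attribute
    spark.createDataFrame([(u'Alice', 1)], ['name', 'age']).collect()
    # [Row(name=u'Alice', age=1)]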
@@ -113,9 +85,7 @@ class SQLContext(object):
Subclasses can override this property to provide their own
JVM Contexts.
"""
- if self._scala_SQLContext is None:
- self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
- return self._scala_SQLContext
+ return self._jsqlContext
@classmethod
@since(1.6)
@@ -127,7 +97,8 @@ class SQLContext(object):
"""
if cls._instantiatedContext is None:
jsqlContext = sc._jvm.SQLContext.getOrCreate(sc._jsc.sc())
- cls(sc, jsqlContext)
+ sparkSession = SparkSession(sc, jsqlContext.sparkSession())
+ cls(sc, sparkSession, jsqlContext)
return cls._instantiatedContext
@since(1.6)
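getOrCreate now also recovers the SparkSession behind the JVM-side singleton before constructing the Python wrapper, so repeated calls keep returning the same context. A short sketch, again assuming a live `sc`:

    ctx1 = SQLContext.getOrCreate(sc)
    ctx2 = SQLContext.getOrCreate(sc)
    assert ctx1 is ctx2                    # the singleton Python wrapper is reused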
@@ -137,14 +108,13 @@ class SQLContext(object):
registered temporary tables and UDFs, but shared SparkContext and
table cache.
"""
- jsqlContext = self._ssql_ctx.newSession()
- return self.__class__(self._sc, jsqlContext)
+ return self.__class__(self._sc, self.sparkSession.newSession())
@since(1.3)
def setConf(self, key, value):
"""Sets the given Spark SQL configuration property.
"""
- self._ssql_ctx.setConf(key, value)
+ self.sparkSession.setConf(key, value)
@ignore_unicode_prefix
@since(1.3)
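newSession and setConf become thin calls into the wrapped SparkSession; the user-facing behaviour is unchanged, for example:

    other = sqlContext.newSession()                         # own temp tables and UDFs,
    other.setConf("spark.sql.shuffle.partitions", "10")     # shared SparkContext and cache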
@@ -163,10 +133,7 @@ class SQLContext(object):
>>> sqlContext.getConf("spark.sql.shuffle.partitions", "10")
u'50'
"""
- if defaultValue is not None:
- return self._ssql_ctx.getConf(key, defaultValue)
- else:
- return self._ssql_ctx.getConf(key)
+ return self.sparkSession.getConf(key, defaultValue)
@property
@since("1.3.1")
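Likewise, getConf collapses into a single delegating call, with the default-value handling pushed down into the session. A small sketch:

    sqlContext.setConf("spark.sql.shuffle.partitions", "10")
    sqlContext.getConf("spark.sql.shuffle.partitions")              # u'10'
    sqlContext.getConf("spark.sql.shuffle.partitions", u"50")       # still u'10': the key is set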
@@ -175,7 +142,7 @@ class SQLContext(object):
:return: :class:`UDFRegistration`
"""
- return UDFRegistration(self)
+ return UDFRegistration(self.sparkSession)
@since(1.4)
def range(self, start, end=None, step=1, numPartitions=None):
@@ -198,15 +165,7 @@ class SQLContext(object):
>>> sqlContext.range(3).collect()
[Row(id=0), Row(id=1), Row(id=2)]
"""
- if numPartitions is None:
- numPartitions = self._sc.defaultParallelism
-
- if end is None:
- jdf = self._ssql_ctx.range(0, int(start), int(step), int(numPartitions))
- else:
- jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions))
-
- return DataFrame(jdf, self)
+ return self.sparkSession.range(start, end, step, numPartitions)
@ignore_unicode_prefix
@since(1.2)
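range keeps its Python signature, but the single-argument form and the defaultParallelism fallback now live in SparkSession.range. Behaviour matches the existing doctest, e.g.:

    sqlContext.range(3).collect()              # [Row(id=0), Row(id=1), Row(id=2)]
    sqlContext.range(1, 7, 2).collect()        # [Row(id=1), Row(id=3), Row(id=5)]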
@@ -236,27 +195,9 @@ class SQLContext(object):
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(stringLengthInt(test)=4)]
"""
- udf = UserDefinedFunction(f, returnType, name)
- self._ssql_ctx.udf().registerPython(name, udf._judf)
-
- def _inferSchemaFromList(self, data):
- """
- Infer schema from list of Row or tuple.
-
- :param data: list of Row or tuple
- :return: StructType
- """
- if not data:
- raise ValueError("can not infer schema from empty dataset")
- first = data[0]
- if type(first) is dict:
- warnings.warn("inferring schema from dict is deprecated,"
- "please use pyspark.sql.Row instead")
- schema = reduce(_merge_type, map(_infer_schema, data))
- if _has_nulltype(schema):
- raise ValueError("Some of types cannot be determined after inferring")
- return schema
+ self.sparkSession.registerFunction(name, f, returnType)
+ # TODO(andrew): delete this once we refactor things to take in SparkSession
def _inferSchema(self, rdd, samplingRatio=None):
"""
Infer schema from an RDD of Row or tuple.
@@ -265,78 +206,7 @@ class SQLContext(object):
:param samplingRatio: sampling ratio, or no sampling (default)
:return: StructType
"""
- first = rdd.first()
- if not first:
- raise ValueError("The first row in RDD is empty, "
- "can not infer schema")
- if type(first) is dict:
- warnings.warn("Using RDD of dict to inferSchema is deprecated. "
- "Use pyspark.sql.Row instead")
-
- if samplingRatio is None:
- schema = _infer_schema(first)
- if _has_nulltype(schema):
- for row in rdd.take(100)[1:]:
- schema = _merge_type(schema, _infer_schema(row))
- if not _has_nulltype(schema):
- break
- else:
- raise ValueError("Some of types cannot be determined by the "
- "first 100 rows, please try again with sampling")
- else:
- if samplingRatio < 0.99:
- rdd = rdd.sample(False, float(samplingRatio))
- schema = rdd.map(_infer_schema).reduce(_merge_type)
- return schema
-
- def _createFromRDD(self, rdd, schema, samplingRatio):
- """
- Create an RDD for DataFrame from an existing RDD, returns the RDD and schema.
- """
- if schema is None or isinstance(schema, (list, tuple)):
- struct = self._inferSchema(rdd, samplingRatio)
- converter = _create_converter(struct)
- rdd = rdd.map(converter)
- if isinstance(schema, (list, tuple)):
- for i, name in enumerate(schema):
- struct.fields[i].name = name
- struct.names[i] = name
- schema = struct
-
- elif not isinstance(schema, StructType):
- raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
-
- # convert python objects to sql data
- rdd = rdd.map(schema.toInternal)
- return rdd, schema
-
- def _createFromLocal(self, data, schema):
- """
- Create an RDD for DataFrame from an list or pandas.DataFrame, returns
- the RDD and schema.
- """
- # make sure data could consumed multiple times
- if not isinstance(data, list):
- data = list(data)
-
- if schema is None or isinstance(schema, (list, tuple)):
- struct = self._inferSchemaFromList(data)
- if isinstance(schema, (list, tuple)):
- for i, name in enumerate(schema):
- struct.fields[i].name = name
- struct.names[i] = name
- schema = struct
-
- elif isinstance(schema, StructType):
- for row in data:
- _verify_type(row, schema)
-
- else:
- raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
-
- # convert python objects to sql data
- data = [schema.toInternal(row) for row in data]
- return self._sc.parallelize(data), schema
+ return self.sparkSession._inferSchema(rdd, samplingRatio)
@since(1.3)
@ignore_unicode_prefix
@@ -421,40 +291,7 @@ class SQLContext(object):
...
Py4JJavaError: ...
"""
- if isinstance(data, DataFrame):
- raise TypeError("data is already a DataFrame")
-
- if isinstance(schema, basestring):
- schema = _parse_datatype_string(schema)
-
- if has_pandas and isinstance(data, pandas.DataFrame):
- if schema is None:
- schema = [str(x) for x in data.columns]
- data = [r.tolist() for r in data.to_records(index=False)]
-
- if isinstance(schema, StructType):
- def prepare(obj):
- _verify_type(obj, schema)
- return obj
- elif isinstance(schema, DataType):
- datatype = schema
-
- def prepare(obj):
- _verify_type(obj, datatype)
- return (obj, )
- schema = StructType().add("value", datatype)
- else:
- prepare = lambda obj: obj
-
- if isinstance(data, RDD):
- rdd, schema = self._createFromRDD(data.map(prepare), schema, samplingRatio)
- else:
- rdd, schema = self._createFromLocal(map(prepare, data), schema)
- jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
- jdf = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
- df = DataFrame(jdf, self)
- df._schema = schema
- return df
+ return self.sparkSession.createDataFrame(data, schema, samplingRatio)
@since(1.3)
def registerDataFrameAsTable(self, df, tableName):
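All of the schema inference, pandas conversion and verification removed above now lives in SparkSession.createDataFrame; the SQLContext entry point is a one-line pass-through that accepts the same inputs. A brief sketch:

    df = sqlContext.createDataFrame([(1, u'row1'), (2, u'row2')], ['f1', 'f2'])
    df.collect()                               # [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2')]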
@@ -464,10 +301,7 @@ class SQLContext(object):
>>> sqlContext.registerDataFrameAsTable(df, "table1")
"""
- if (df.__class__ is DataFrame):
- self._ssql_ctx.registerDataFrameAsTable(df._jdf, tableName)
- else:
- raise ValueError("Can only register DataFrame as table")
+ self.sparkSession.registerDataFrameAsTable(df, tableName)
@since(1.6)
def dropTempTable(self, tableName):
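registerDataFrameAsTable now defers its type check to the session as well; the temp-table round trip is unchanged (reusing `df` from the sketch above):

    sqlContext.registerDataFrameAsTable(df, "table1")
    sqlContext.sql("SELECT f1 FROM table1").count()         # 2
    sqlContext.dropTempTable("table1")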
@@ -493,20 +327,7 @@ class SQLContext(object):
:return: :class:`DataFrame`
"""
- if path is not None:
- options["path"] = path
- if source is None:
- source = self.getConf("spark.sql.sources.default",
- "org.apache.spark.sql.parquet")
- if schema is None:
- df = self._ssql_ctx.createExternalTable(tableName, source, options)
- else:
- if not isinstance(schema, StructType):
- raise TypeError("schema should be StructType")
- scala_datatype = self._ssql_ctx.parseDataType(schema.json())
- df = self._ssql_ctx.createExternalTable(tableName, source, scala_datatype,
- options)
- return DataFrame(df, self)
+ return self.sparkSession.createExternalTable(tableName, path, source, schema, **options)
@ignore_unicode_prefix
@since(1.0)
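createExternalTable similarly forwards the path, source, schema and remaining keyword options straight to the session. A hedged sketch only; the path and format below are placeholders, not values from this patch:

    people = sqlContext.createExternalTable(
        "people_ext",                   # table name
        path="/tmp/people.json",        # hypothetical location of existing data
        source="json")                  # hypothetical data source name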
@@ -520,7 +341,7 @@ class SQLContext(object):
>>> df2.collect()
[Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
"""
- return DataFrame(self._ssql_ctx.sql(sqlQuery), self)
+ return self.sparkSession.sql(sqlQuery)
@since(1.0)
def table(self, tableName):
@@ -533,7 +354,7 @@ class SQLContext(object):
>>> sorted(df.collect()) == sorted(df2.collect())
True
"""
- return DataFrame(self._ssql_ctx.table(tableName), self)
+ return self.sparkSession.table(tableName)
@ignore_unicode_prefix
@since(1.3)
@@ -603,7 +424,7 @@ class SQLContext(object):
return DataFrameReader(self)
-# TODO(andrew): remove this too
+# TODO(andrew): deprecate this
class HiveContext(SQLContext):
"""A variant of Spark SQL that integrates with data stored in Hive.
@@ -611,29 +432,28 @@ class HiveContext(SQLContext):
It supports running both SQL and HiveQL commands.
:param sparkContext: The SparkContext to wrap.
- :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instantiate a new
+ :param jhiveContext: An optional JVM Scala HiveContext. If set, we do not instantiate a new
:class:`HiveContext` in the JVM; instead we make all calls to this object.
"""
- def __init__(self, sparkContext, hiveContext=None):
- SQLContext.__init__(self, sparkContext)
- if hiveContext:
- self._scala_HiveContext = hiveContext
+ def __init__(self, sparkContext, jhiveContext=None):
+ if jhiveContext is None:
+ sparkSession = SparkSession.withHiveSupport(sparkContext)
+ else:
+ sparkSession = SparkSession(sparkContext, jhiveContext.sparkSession())
+ SQLContext.__init__(self, sparkContext, sparkSession, jhiveContext)
+
+ @classmethod
+ def _createForTesting(cls, sparkContext):
+ """(Internal use only) Create a new HiveContext for testing.
- @property
- def _ssql_ctx(self):
- try:
- if not hasattr(self, '_scala_HiveContext'):
- self._scala_HiveContext = self._get_hive_ctx()
- return self._scala_HiveContext
- except Py4JError as e:
- print("You must build Spark with Hive. "
- "Export 'SPARK_HIVE=true' and run "
- "build/sbt assembly", file=sys.stderr)
- raise
-
- def _get_hive_ctx(self):
- return self._jvm.SparkSession.withHiveSupport(self._jsc.sc()).wrapped()
+ All test code that touches HiveContext *must* go through this method. Otherwise,
+ you may end up launching multiple Derby instances and encountering incredibly
+ confusing error messages.
+ """
+ jsc = sparkContext._jsc.sc()
+ jtestHive = sparkContext._jvm.org.apache.spark.sql.hive.test.TestHiveContext(jsc)
+ return cls(sparkContext, jtestHive)
def refreshTable(self, tableName):
"""Invalidate and refresh all the cached metadata of the given