Diffstat (limited to 'python/pyspark/sql/context.py')
-rw-r--r-- | python/pyspark/sql/context.py | 278
1 file changed, 49 insertions, 229 deletions
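In short, this patch rewires SQLContext (and HiveContext) to delegate to the new SparkSession instead of talking to the JVM Scala SQLContext directly. The following is a condensed, illustrative Python sketch of that wrapper pattern, pieced together from the hunks below; it is not the real file and elides everything else:

    # Illustrative sketch only: SQLContext keeps its public API but forwards
    # work to a SparkSession it owns. Names mirror the hunks in this diff.
    from pyspark.sql.session import SparkSession

    class SQLContext(object):
        def __init__(self, sparkContext, sparkSession=None, jsqlContext=None):
            self._sc = sparkContext
            if sparkSession is None:
                # Build a SparkSession on top of the given SparkContext.
                sparkSession = SparkSession(sparkContext)
            if jsqlContext is None:
                # Reuse the JVM SQLContext that the session already wraps.
                jsqlContext = sparkSession._jwrapped
            self.sparkSession = sparkSession
            self._jsqlContext = jsqlContext

        def sql(self, sqlQuery):
            # Every former self._ssql_ctx.* call now routes through the session.
            return self.sparkSession.sql(sqlQuery)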
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 600a6e0bc2..48ffb59668 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -17,70 +17,37 @@

 from __future__ import print_function
 import sys
-import warnings
-import json
-from functools import reduce

 if sys.version >= '3':
     basestring = unicode = str
-else:
-    from itertools import imap as map
-
-from py4j.protocol import Py4JError

 from pyspark import since
-from pyspark.rdd import RDD, ignore_unicode_prefix
-from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
-from pyspark.sql.types import Row, DataType, StringType, StructType, _verify_type, \
-    _infer_schema, _has_nulltype, _merge_type, _create_converter, _parse_datatype_string
+from pyspark.rdd import ignore_unicode_prefix
+from pyspark.sql.session import _monkey_patch_RDD, SparkSession
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.readwriter import DataFrameReader
+from pyspark.sql.types import Row, StringType
 from pyspark.sql.utils import install_exception_handler
-from pyspark.sql.functions import UserDefinedFunction
-
-try:
-    import pandas
-    has_pandas = True
-except Exception:
-    has_pandas = False

 __all__ = ["SQLContext", "HiveContext", "UDFRegistration"]


-def _monkey_patch_RDD(sqlContext):
-    def toDF(self, schema=None, sampleRatio=None):
-        """
-        Converts current :class:`RDD` into a :class:`DataFrame`
-
-        This is a shorthand for ``sqlContext.createDataFrame(rdd, schema, sampleRatio)``
-
-        :param schema: a StructType or list of names of columns
-        :param samplingRatio: the sample ratio of rows used for inferring
-        :return: a DataFrame
-
-        >>> rdd.toDF().collect()
-        [Row(name=u'Alice', age=1)]
-        """
-        return sqlContext.createDataFrame(self, schema, sampleRatio)
-
-    RDD.toDF = toDF
-
-
 class SQLContext(object):
-    """Main entry point for Spark SQL functionality.
+    """Wrapper around :class:`SparkSession`, the main entry point to Spark SQL functionality.

     A SQLContext can be used create :class:`DataFrame`, register :class:`DataFrame` as
     tables, execute SQL over tables, cache tables, and read parquet files.

     :param sparkContext: The :class:`SparkContext` backing this SQLContext.
-    :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new
+    :param sparkSession: The :class:`SparkSession` around which this SQLContext wraps.
+    :param jsqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new
         SQLContext in the JVM, instead we make all calls to this object.
     """

     _instantiatedContext = None

     @ignore_unicode_prefix
-    def __init__(self, sparkContext, sqlContext=None):
+    def __init__(self, sparkContext, sparkSession=None, jsqlContext=None):
         """Creates a new SQLContext.

         >>> from datetime import datetime
@@ -100,8 +67,13 @@ class SQLContext(object):
         self._sc = sparkContext
         self._jsc = self._sc._jsc
         self._jvm = self._sc._jvm
-        self._scala_SQLContext = sqlContext
-        _monkey_patch_RDD(self)
+        if sparkSession is None:
+            sparkSession = SparkSession(sparkContext)
+        if jsqlContext is None:
+            jsqlContext = sparkSession._jwrapped
+        self.sparkSession = sparkSession
+        self._jsqlContext = jsqlContext
+        _monkey_patch_RDD(self.sparkSession)
         install_exception_handler()
         if SQLContext._instantiatedContext is None:
             SQLContext._instantiatedContext = self
@@ -113,9 +85,7 @@ class SQLContext(object):
         Subclasses can override this property to provide their own
         JVM Contexts.
         """
-        if self._scala_SQLContext is None:
-            self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
-        return self._scala_SQLContext
+        return self._jsqlContext

     @classmethod
     @since(1.6)
@@ -127,7 +97,8 @@ class SQLContext(object):
         """
         if cls._instantiatedContext is None:
             jsqlContext = sc._jvm.SQLContext.getOrCreate(sc._jsc.sc())
-            cls(sc, jsqlContext)
+            sparkSession = SparkSession(sc, jsqlContext.sparkSession())
+            cls(sc, sparkSession, jsqlContext)
         return cls._instantiatedContext

     @since(1.6)
@@ -137,14 +108,13 @@ class SQLContext(object):
         registered temporary tables and UDFs, but shared SparkContext and
         table cache.
         """
-        jsqlContext = self._ssql_ctx.newSession()
-        return self.__class__(self._sc, jsqlContext)
+        return self.__class__(self._sc, self.sparkSession.newSession())

     @since(1.3)
     def setConf(self, key, value):
         """Sets the given Spark SQL configuration property.
         """
-        self._ssql_ctx.setConf(key, value)
+        self.sparkSession.setConf(key, value)

     @ignore_unicode_prefix
     @since(1.3)
@@ -163,10 +133,7 @@ class SQLContext(object):
         >>> sqlContext.getConf("spark.sql.shuffle.partitions", "10")
         u'50'
         """
-        if defaultValue is not None:
-            return self._ssql_ctx.getConf(key, defaultValue)
-        else:
-            return self._ssql_ctx.getConf(key)
+        return self.sparkSession.getConf(key, defaultValue)

     @property
     @since("1.3.1")
@@ -175,7 +142,7 @@ class SQLContext(object):

         :return: :class:`UDFRegistration`
         """
-        return UDFRegistration(self)
+        return UDFRegistration(self.sparkSession)

     @since(1.4)
     def range(self, start, end=None, step=1, numPartitions=None):
@@ -198,15 +165,7 @@ class SQLContext(object):
         >>> sqlContext.range(3).collect()
         [Row(id=0), Row(id=1), Row(id=2)]
         """
-        if numPartitions is None:
-            numPartitions = self._sc.defaultParallelism
-
-        if end is None:
-            jdf = self._ssql_ctx.range(0, int(start), int(step), int(numPartitions))
-        else:
-            jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions))
-
-        return DataFrame(jdf, self)
+        return self.sparkSession.range(start, end, step, numPartitions)

     @ignore_unicode_prefix
     @since(1.2)
@@ -236,27 +195,9 @@ class SQLContext(object):
         >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]
         """
-        udf = UserDefinedFunction(f, returnType, name)
-        self._ssql_ctx.udf().registerPython(name, udf._judf)
-
-    def _inferSchemaFromList(self, data):
-        """
-        Infer schema from list of Row or tuple.
-
-        :param data: list of Row or tuple
-        :return: StructType
-        """
-        if not data:
-            raise ValueError("can not infer schema from empty dataset")
-        first = data[0]
-        if type(first) is dict:
-            warnings.warn("inferring schema from dict is deprecated,"
-                          "please use pyspark.sql.Row instead")
-        schema = reduce(_merge_type, map(_infer_schema, data))
-        if _has_nulltype(schema):
-            raise ValueError("Some of types cannot be determined after inferring")
-        return schema
+        self.sparkSession.registerFunction(name, f, returnType)

+    # TODO(andrew): delete this once we refactor things to take in SparkSession
     def _inferSchema(self, rdd, samplingRatio=None):
         """
         Infer schema from an RDD of Row or tuple.
@@ -265,78 +206,7 @@ class SQLContext(object):
         :param rdd: an RDD of Row or tuple
         :param samplingRatio: sampling ratio, or no sampling (default)
         :return: StructType
         """
-        first = rdd.first()
-        if not first:
-            raise ValueError("The first row in RDD is empty, "
-                             "can not infer schema")
-        if type(first) is dict:
-            warnings.warn("Using RDD of dict to inferSchema is deprecated. "
-                          "Use pyspark.sql.Row instead")
-
-        if samplingRatio is None:
-            schema = _infer_schema(first)
-            if _has_nulltype(schema):
-                for row in rdd.take(100)[1:]:
-                    schema = _merge_type(schema, _infer_schema(row))
-                    if not _has_nulltype(schema):
-                        break
-                else:
-                    raise ValueError("Some of types cannot be determined by the "
-                                     "first 100 rows, please try again with sampling")
-        else:
-            if samplingRatio < 0.99:
-                rdd = rdd.sample(False, float(samplingRatio))
-            schema = rdd.map(_infer_schema).reduce(_merge_type)
-        return schema
-
-    def _createFromRDD(self, rdd, schema, samplingRatio):
-        """
-        Create an RDD for DataFrame from an existing RDD, returns the RDD and schema.
-        """
-        if schema is None or isinstance(schema, (list, tuple)):
-            struct = self._inferSchema(rdd, samplingRatio)
-            converter = _create_converter(struct)
-            rdd = rdd.map(converter)
-            if isinstance(schema, (list, tuple)):
-                for i, name in enumerate(schema):
-                    struct.fields[i].name = name
-                    struct.names[i] = name
-            schema = struct
-
-        elif not isinstance(schema, StructType):
-            raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
-
-        # convert python objects to sql data
-        rdd = rdd.map(schema.toInternal)
-        return rdd, schema
-
-    def _createFromLocal(self, data, schema):
-        """
-        Create an RDD for DataFrame from an list or pandas.DataFrame, returns
-        the RDD and schema.
-        """
-        # make sure data could consumed multiple times
-        if not isinstance(data, list):
-            data = list(data)
-
-        if schema is None or isinstance(schema, (list, tuple)):
-            struct = self._inferSchemaFromList(data)
-            if isinstance(schema, (list, tuple)):
-                for i, name in enumerate(schema):
-                    struct.fields[i].name = name
-                    struct.names[i] = name
-            schema = struct
-
-        elif isinstance(schema, StructType):
-            for row in data:
-                _verify_type(row, schema)
-
-        else:
-            raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
-
-        # convert python objects to sql data
-        data = [schema.toInternal(row) for row in data]
-        return self._sc.parallelize(data), schema
+        return self.sparkSession._inferSchema(rdd, samplingRatio)

     @since(1.3)
     @ignore_unicode_prefix
@@ -421,40 +291,7 @@ class SQLContext(object):
             ...
         Py4JJavaError: ...
         """
-        if isinstance(data, DataFrame):
-            raise TypeError("data is already a DataFrame")
-
-        if isinstance(schema, basestring):
-            schema = _parse_datatype_string(schema)
-
-        if has_pandas and isinstance(data, pandas.DataFrame):
-            if schema is None:
-                schema = [str(x) for x in data.columns]
-            data = [r.tolist() for r in data.to_records(index=False)]
-
-        if isinstance(schema, StructType):
-            def prepare(obj):
-                _verify_type(obj, schema)
-                return obj
-        elif isinstance(schema, DataType):
-            datatype = schema
-
-            def prepare(obj):
-                _verify_type(obj, datatype)
-                return (obj, )
-            schema = StructType().add("value", datatype)
-        else:
-            prepare = lambda obj: obj
-
-        if isinstance(data, RDD):
-            rdd, schema = self._createFromRDD(data.map(prepare), schema, samplingRatio)
-        else:
-            rdd, schema = self._createFromLocal(map(prepare, data), schema)
-        jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
-        jdf = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
-        df = DataFrame(jdf, self)
-        df._schema = schema
-        return df
+        return self.sparkSession.createDataFrame(data, schema, samplingRatio)

     @since(1.3)
     def registerDataFrameAsTable(self, df, tableName):
@@ -464,10 +301,7 @@ class SQLContext(object):

         >>> sqlContext.registerDataFrameAsTable(df, "table1")
         """
-        if (df.__class__ is DataFrame):
-            self._ssql_ctx.registerDataFrameAsTable(df._jdf, tableName)
-        else:
-            raise ValueError("Can only register DataFrame as table")
+        self.sparkSession.registerDataFrameAsTable(df, tableName)

     @since(1.6)
     def dropTempTable(self, tableName):
@@ -493,20 +327,7 @@ class SQLContext(object):

         :return: :class:`DataFrame`
         """
-        if path is not None:
-            options["path"] = path
-        if source is None:
-            source = self.getConf("spark.sql.sources.default",
-                                  "org.apache.spark.sql.parquet")
-        if schema is None:
-            df = self._ssql_ctx.createExternalTable(tableName, source, options)
-        else:
-            if not isinstance(schema, StructType):
-                raise TypeError("schema should be StructType")
-            scala_datatype = self._ssql_ctx.parseDataType(schema.json())
-            df = self._ssql_ctx.createExternalTable(tableName, source, scala_datatype,
-                                                    options)
-        return DataFrame(df, self)
+        return self.sparkSession.createExternalTable(tableName, path, source, schema, **options)

     @ignore_unicode_prefix
     @since(1.0)
@@ -520,7 +341,7 @@ class SQLContext(object):
         >>> df2.collect()
         [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
         """
-        return DataFrame(self._ssql_ctx.sql(sqlQuery), self)
+        return self.sparkSession.sql(sqlQuery)

     @since(1.0)
     def table(self, tableName):
@@ -533,7 +354,7 @@ class SQLContext(object):
         >>> sorted(df.collect()) == sorted(df2.collect())
         True
         """
-        return DataFrame(self._ssql_ctx.table(tableName), self)
+        return self.sparkSession.table(tableName)

     @ignore_unicode_prefix
     @since(1.3)
@@ -603,7 +424,7 @@ class SQLContext(object):
         return DataFrameReader(self)


-# TODO(andrew): remove this too
+# TODO(andrew): deprecate this
 class HiveContext(SQLContext):
     """A variant of Spark SQL that integrates with data stored in Hive.

@@ -611,29 +432,28 @@ class HiveContext(SQLContext):
     It supports running both SQL and HiveQL commands.

     :param sparkContext: The SparkContext to wrap.
-    :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instantiate a new
+    :param jhiveContext: An optional JVM Scala HiveContext. If set, we do not instantiate a new
         :class:`HiveContext` in the JVM, instead we make all calls to this object.
     """

-    def __init__(self, sparkContext, hiveContext=None):
-        SQLContext.__init__(self, sparkContext)
-        if hiveContext:
-            self._scala_HiveContext = hiveContext
+    def __init__(self, sparkContext, jhiveContext=None):
+        if jhiveContext is None:
+            sparkSession = SparkSession.withHiveSupport(sparkContext)
+        else:
+            sparkSession = SparkSession(sparkContext, jhiveContext.sparkSession())
+        SQLContext.__init__(self, sparkContext, sparkSession, jhiveContext)
+
+    @classmethod
+    def _createForTesting(cls, sparkContext):
+        """(Internal use only) Create a new HiveContext for testing.

-    @property
-    def _ssql_ctx(self):
-        try:
-            if not hasattr(self, '_scala_HiveContext'):
-                self._scala_HiveContext = self._get_hive_ctx()
-            return self._scala_HiveContext
-        except Py4JError as e:
-            print("You must build Spark with Hive. "
-                  "Export 'SPARK_HIVE=true' and run "
-                  "build/sbt assembly", file=sys.stderr)
-            raise
-
-    def _get_hive_ctx(self):
-        return self._jvm.SparkSession.withHiveSupport(self._jsc.sc()).wrapped()
+        All test code that touches HiveContext *must* go through this method. Otherwise,
+        you may end up launching multiple derby instances and encounter with incredibly
+        confusing error messages.
+        """
+        jsc = sparkContext._jsc.sc()
+        jtestHive = sparkContext._jvm.org.apache.spark.sql.hive.test.TestHiveContext(jsc)
+        return cls(sparkContext, jtestHive)

     def refreshTable(self, tableName):
         """Invalidate and refresh all the cached the metadata of the given