author | Sandeep Singh <sandeep@techaddict.me> | 2016-05-11 11:24:16 -0700
committer | Davies Liu <davies.liu@gmail.com> | 2016-05-11 11:24:16 -0700
commit | 29314379729de4082bd2297c9e5289e3e4a0115e (patch)
tree | a5aede7207fde856910581f7f97f4b65b73a6e39 /python/pyspark/sql/readwriter.py
parent | d8935db5ecb7c959585411da9bf1e9a9c4d5cb37 (diff)
download | spark-29314379729de4082bd2297c9e5289e3e4a0115e.tar.gz, spark-29314379729de4082bd2297c9e5289e3e4a0115e.tar.bz2, spark-29314379729de4082bd2297c9e5289e3e4a0115e.zip
[SPARK-15037] [SQL] [MLLIB] Part2: Use SparkSession instead of SQLContext in Python TestSuites
## What changes were proposed in this pull request?
Use SparkSession instead of SQLContext in the Python test suites. For this file, that means the doctests and the `_test()` harness in `python/pyspark/sql/readwriter.py` now go through a `spark` (SparkSession) global rather than a `sqlContext` global.
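At the level of an individual test the change is mechanical: drop the `SQLContext(sc)` fixture and go through a `SparkSession` exposed to the doctests as `spark`. A minimal sketch of the pattern, assuming a local Spark build on `PYTHONPATH` and a working directory of `SPARK_HOME` (the builder chain and the `people.json` fixture path are taken from the patch itself):

```python
from pyspark.sql import SparkSession  # previously: from pyspark.sql import SQLContext

# Before this patch the test globals held `sqlContext = SQLContext(sc)`.
# After it they hold a (Hive-enabled) SparkSession instead:
spark = SparkSession.builder \
    .enableHiveSupport() \
    .getOrCreate()

# Doctests now read through spark.read rather than sqlContext.read.
df = spark.read.json("python/test_support/sql/people.json")
print(df.dtypes)  # [('age', 'bigint'), ('name', 'string')]
```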
## How was this patch tested?
Existing tests
Author: Sandeep Singh <sandeep@techaddict.me>
Closes #13044 from techaddict/SPARK-15037-python.
Diffstat (limited to 'python/pyspark/sql/readwriter.py')
-rw-r--r-- | python/pyspark/sql/readwriter.py | 72
1 file changed, 37 insertions(+), 35 deletions(-)
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 7e79df33e8..bd728c97c8 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -47,19 +47,19 @@ def to_str(value):
 class DataFrameReader(object):
     """
     Interface used to load a :class:`DataFrame` from external storage systems
-    (e.g. file systems, key-value stores, etc). Use :func:`SQLContext.read`
+    (e.g. file systems, key-value stores, etc). Use :func:`spark.read`
     to access this.

     .. versionadded:: 1.4
     """

-    def __init__(self, sqlContext):
-        self._jreader = sqlContext._ssql_ctx.read()
-        self._sqlContext = sqlContext
+    def __init__(self, spark):
+        self._jreader = spark._ssql_ctx.read()
+        self._spark = spark

     def _df(self, jdf):
         from pyspark.sql.dataframe import DataFrame
-        return DataFrame(jdf, self._sqlContext)
+        return DataFrame(jdf, self._spark)

     @since(1.4)
     def format(self, source):
@@ -67,7 +67,7 @@ class DataFrameReader(object):

         :param source: string, name of the data source, e.g. 'json', 'parquet'.

-        >>> df = sqlContext.read.format('json').load('python/test_support/sql/people.json')
+        >>> df = spark.read.format('json').load('python/test_support/sql/people.json')
         >>> df.dtypes
         [('age', 'bigint'), ('name', 'string')]

@@ -87,7 +87,7 @@ class DataFrameReader(object):
         """
         if not isinstance(schema, StructType):
             raise TypeError("schema should be StructType")
-        jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
+        jschema = self._spark._ssql_ctx.parseDataType(schema.json())
         self._jreader = self._jreader.schema(jschema)
         return self

@@ -115,12 +115,12 @@ class DataFrameReader(object):
         :param schema: optional :class:`StructType` for the input schema.
         :param options: all other string options

-        >>> df = sqlContext.read.load('python/test_support/sql/parquet_partitioned', opt1=True,
+        >>> df = spark.read.load('python/test_support/sql/parquet_partitioned', opt1=True,
         ...     opt2=1, opt3='str')
         >>> df.dtypes
         [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]

-        >>> df = sqlContext.read.format('json').load(['python/test_support/sql/people.json',
+        >>> df = spark.read.format('json').load(['python/test_support/sql/people.json',
         ...     'python/test_support/sql/people1.json'])
         >>> df.dtypes
         [('age', 'bigint'), ('aka', 'string'), ('name', 'string')]
@@ -133,7 +133,7 @@ class DataFrameReader(object):
         if path is not None:
             if type(path) != list:
                 path = [path]
-            return self._df(self._jreader.load(self._sqlContext._sc._jvm.PythonUtils.toSeq(path)))
+            return self._df(self._jreader.load(self._spark._sc._jvm.PythonUtils.toSeq(path)))
         else:
             return self._df(self._jreader.load())

@@ -148,7 +148,7 @@ class DataFrameReader(object):
         :param schema: optional :class:`StructType` for the input schema.
         :param options: all other string options

-        >>> df = sqlContext.read.format('text').stream('python/test_support/sql/streaming')
+        >>> df = spark.read.format('text').stream('python/test_support/sql/streaming')
         >>> df.isStreaming
         True
         """
@@ -211,11 +211,11 @@ class DataFrameReader(object):
                                           ``spark.sql.columnNameOfCorruptRecord``. If None is set,
                                           it uses the default value ``_corrupt_record``.

-        >>> df1 = sqlContext.read.json('python/test_support/sql/people.json')
+        >>> df1 = spark.read.json('python/test_support/sql/people.json')
         >>> df1.dtypes
         [('age', 'bigint'), ('name', 'string')]
         >>> rdd = sc.textFile('python/test_support/sql/people.json')
-        >>> df2 = sqlContext.read.json(rdd)
+        >>> df2 = spark.read.json(rdd)
         >>> df2.dtypes
         [('age', 'bigint'), ('name', 'string')]

@@ -243,7 +243,7 @@ class DataFrameReader(object):
         if isinstance(path, basestring):
             path = [path]
         if type(path) == list:
-            return self._df(self._jreader.json(self._sqlContext._sc._jvm.PythonUtils.toSeq(path)))
+            return self._df(self._jreader.json(self._spark._sc._jvm.PythonUtils.toSeq(path)))
         elif isinstance(path, RDD):
             def func(iterator):
                 for x in iterator:
@@ -254,7 +254,7 @@ class DataFrameReader(object):
                     yield x
             keyed = path.mapPartitions(func)
             keyed._bypass_serializer = True
-            jrdd = keyed._jrdd.map(self._sqlContext._jvm.BytesToString())
+            jrdd = keyed._jrdd.map(self._spark._jvm.BytesToString())
             return self._df(self._jreader.json(jrdd))
         else:
             raise TypeError("path can be only string or RDD")
@@ -265,9 +265,9 @@ class DataFrameReader(object):

         :param tableName: string, name of the table.

-        >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned')
+        >>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned')
         >>> df.registerTempTable('tmpTable')
-        >>> sqlContext.read.table('tmpTable').dtypes
+        >>> spark.read.table('tmpTable').dtypes
         [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
         """
         return self._df(self._jreader.table(tableName))
@@ -276,11 +276,11 @@ class DataFrameReader(object):
     def parquet(self, *paths):
         """Loads a Parquet file, returning the result as a :class:`DataFrame`.

-        >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned')
+        >>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned')
         >>> df.dtypes
         [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
         """
-        return self._df(self._jreader.parquet(_to_seq(self._sqlContext._sc, paths)))
+        return self._df(self._jreader.parquet(_to_seq(self._spark._sc, paths)))

     @ignore_unicode_prefix
     @since(1.6)
@@ -291,13 +291,13 @@ class DataFrameReader(object):

         :param paths: string, or list of strings, for input path(s).

-        >>> df = sqlContext.read.text('python/test_support/sql/text-test.txt')
+        >>> df = spark.read.text('python/test_support/sql/text-test.txt')
         >>> df.collect()
         [Row(value=u'hello'), Row(value=u'this')]
         """
         if isinstance(paths, basestring):
             path = [paths]
-        return self._df(self._jreader.text(self._sqlContext._sc._jvm.PythonUtils.toSeq(path)))
+        return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(path)))

     @since(2.0)
     def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None,
@@ -356,7 +356,7 @@ class DataFrameReader(object):
                 * ``DROPMALFORMED`` : ignores the whole corrupted records.
                 * ``FAILFAST`` : throws an exception when it meets corrupted records.

-        >>> df = sqlContext.read.csv('python/test_support/sql/ages.csv')
+        >>> df = spark.read.csv('python/test_support/sql/ages.csv')
         >>> df.dtypes
         [('C0', 'string'), ('C1', 'string')]
         """
@@ -396,7 +396,7 @@ class DataFrameReader(object):
             self.option("mode", mode)
         if isinstance(path, basestring):
             path = [path]
-        return self._df(self._jreader.csv(self._sqlContext._sc._jvm.PythonUtils.toSeq(path)))
+        return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path)))

     @since(1.5)
     def orc(self, path):
@@ -441,16 +441,16 @@ class DataFrameReader(object):
         """
         if properties is None:
             properties = dict()
-        jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
+        jprop = JavaClass("java.util.Properties", self._spark._sc._gateway._gateway_client)()
         for k in properties:
             jprop.setProperty(k, properties[k])
         if column is not None:
             if numPartitions is None:
-                numPartitions = self._sqlContext._sc.defaultParallelism
+                numPartitions = self._spark._sc.defaultParallelism
             return self._df(self._jreader.jdbc(url, table, column, int(lowerBound), int(upperBound),
                                                int(numPartitions), jprop))
         if predicates is not None:
-            gateway = self._sqlContext._sc._gateway
+            gateway = self._spark._sc._gateway
             jpredicates = utils.toJArray(gateway, gateway.jvm.java.lang.String, predicates)
             return self._df(self._jreader.jdbc(url, table, jpredicates, jprop))
         return self._df(self._jreader.jdbc(url, table, jprop))
@@ -466,7 +466,7 @@ class DataFrameWriter(object):
     """
     def __init__(self, df):
         self._df = df
-        self._sqlContext = df.sql_ctx
+        self._spark = df.sql_ctx
         self._jwrite = df._jdf.write()

     def _cq(self, jcq):
@@ -531,14 +531,14 @@ class DataFrameWriter(object):
         """
         if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
             cols = cols[0]
-        self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
+        self._jwrite = self._jwrite.partitionBy(_to_seq(self._spark._sc, cols))
         return self

     @since(2.0)
     def queryName(self, queryName):
         """Specifies the name of the :class:`ContinuousQuery` that can be started with
         :func:`startStream`. This name must be unique among all the currently active queries
-        in the associated SQLContext
+        in the associated spark

         .. note:: Experimental.

@@ -573,7 +573,7 @@ class DataFrameWriter(object):
             trigger = ProcessingTime(processingTime)
         if trigger is None:
             raise ValueError('A trigger was not provided. Supported triggers: processingTime.')
-        self._jwrite = self._jwrite.trigger(trigger._to_java_trigger(self._sqlContext))
+        self._jwrite = self._jwrite.trigger(trigger._to_java_trigger(self._spark))
         return self

     @since(1.4)
@@ -854,7 +854,7 @@ class DataFrameWriter(object):
         """
         if properties is None:
             properties = dict()
-        jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
+        jprop = JavaClass("java.util.Properties", self._spark._sc._gateway._gateway_client)()
         for k in properties:
             jprop.setProperty(k, properties[k])
         self._jwrite.mode(mode).jdbc(url, table, jprop)
@@ -865,7 +865,7 @@ def _test():
     import os
     import tempfile
     from pyspark.context import SparkContext
-    from pyspark.sql import Row, SQLContext, HiveContext
+    from pyspark.sql import SparkSession, Row, HiveContext
     import pyspark.sql.readwriter

     os.chdir(os.environ["SPARK_HOME"])
@@ -876,11 +876,13 @@ def _test():
     globs['tempfile'] = tempfile
     globs['os'] = os
     globs['sc'] = sc
-    globs['sqlContext'] = SQLContext(sc)
+    globs['spark'] = SparkSession.builder\
+        .enableHiveSupport()\
+        .getOrCreate()
     globs['hiveContext'] = HiveContext._createForTesting(sc)
-    globs['df'] = globs['sqlContext'].read.parquet('python/test_support/sql/parquet_partitioned')
+    globs['df'] = globs['spark'].read.parquet('python/test_support/sql/parquet_partitioned')
     globs['sdf'] = \
-        globs['sqlContext'].read.format('text').stream('python/test_support/sql/streaming')
+        globs['spark'].read.format('text').stream('python/test_support/sql/streaming')
     (failure_count, test_count) = doctest.testmod(
         pyspark.sql.readwriter, globs=globs,
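The `_test()` hunk at the end of the diff is what makes the rewritten doctests work: the harness now publishes a `spark` global built with `SparkSession.builder` instead of a `sqlContext` global. A simplified, standalone sketch of that harness is below; the `local[4]` master, the app name, and the reduced set of globals are assumptions for illustration, not part of the patch (the real helper also registers `hiveContext`, `sdf`, `tempfile`, and `os`):

```python
import doctest
import os

from pyspark.context import SparkContext
from pyspark.sql import SparkSession
import pyspark.sql.readwriter

# The doctest fixture paths (python/test_support/sql/...) are relative to SPARK_HOME.
os.chdir(os.environ["SPARK_HOME"])

sc = SparkContext("local[4]", "readwriter doctests")  # assumed local test master
globs = pyspark.sql.readwriter.__dict__.copy()
globs["sc"] = sc
# The patched fixture: a Hive-enabled SparkSession replaces SQLContext(sc).
globs["spark"] = SparkSession.builder.enableHiveSupport().getOrCreate()
globs["df"] = globs["spark"].read.parquet("python/test_support/sql/parquet_partitioned")

failure_count, test_count = doctest.testmod(pyspark.sql.readwriter, globs=globs)
sc.stop()
```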