path: root/python/pyspark/sql/readwriter.py
author    Sandeep Singh <sandeep@techaddict.me>    2016-05-11 11:24:16 -0700
committer Davies Liu <davies.liu@gmail.com>    2016-05-11 11:24:16 -0700
commit    29314379729de4082bd2297c9e5289e3e4a0115e (patch)
tree      a5aede7207fde856910581f7f97f4b65b73a6e39 /python/pyspark/sql/readwriter.py
parent    d8935db5ecb7c959585411da9bf1e9a9c4d5cb37 (diff)
[SPARK-15037] [SQL] [MLLIB] Part2: Use SparkSession instead of SQLContext in Python TestSuites
## What changes were proposed in this pull request?

Use SparkSession instead of SQLContext in Python TestSuites.

## How was this patch tested?

Existing tests.

Author: Sandeep Singh <sandeep@techaddict.me>

Closes #13044 from techaddict/SPARK-15037-python.
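For context, a minimal sketch of the entry-point change this patch carries through the doctests: the `SparkSession` builder replaces direct `SQLContext` construction, so `sqlContext.read` becomes `spark.read`. The master and app name below are illustrative, not taken from the patch.

```python
from pyspark import SparkContext
from pyspark.sql import SparkSession, SQLContext

# Pre-2.0 entry point: wrap an existing SparkContext in a SQLContext.
sc = SparkContext("local[2]", "readwriter-demo")  # illustrative master/app name
sqlContext = SQLContext(sc)

# Unified entry point: build (or reuse) a SparkSession.
spark = SparkSession.builder.appName("readwriter-demo").getOrCreate()

# Both expose a DataFrameReader, so the doctests only swap the prefix.
df_old = sqlContext.read.json("python/test_support/sql/people.json")
df_new = spark.read.json("python/test_support/sql/people.json")
```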
Diffstat (limited to 'python/pyspark/sql/readwriter.py')
-rw-r--r--  python/pyspark/sql/readwriter.py  72
1 file changed, 37 insertions(+), 35 deletions(-)
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 7e79df33e8..bd728c97c8 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -47,19 +47,19 @@ def to_str(value):
class DataFrameReader(object):
"""
Interface used to load a :class:`DataFrame` from external storage systems
- (e.g. file systems, key-value stores, etc). Use :func:`SQLContext.read`
+ (e.g. file systems, key-value stores, etc). Use :func:`spark.read`
to access this.
.. versionadded:: 1.4
"""
- def __init__(self, sqlContext):
- self._jreader = sqlContext._ssql_ctx.read()
- self._sqlContext = sqlContext
+ def __init__(self, spark):
+ self._jreader = spark._ssql_ctx.read()
+ self._spark = spark
def _df(self, jdf):
from pyspark.sql.dataframe import DataFrame
- return DataFrame(jdf, self._sqlContext)
+ return DataFrame(jdf, self._spark)
@since(1.4)
def format(self, source):
@@ -67,7 +67,7 @@ class DataFrameReader(object):
:param source: string, name of the data source, e.g. 'json', 'parquet'.
- >>> df = sqlContext.read.format('json').load('python/test_support/sql/people.json')
+ >>> df = spark.read.format('json').load('python/test_support/sql/people.json')
>>> df.dtypes
[('age', 'bigint'), ('name', 'string')]
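Outside of a doctest, the same pattern reads as follows; a short sketch assuming an active SparkSession bound to `spark`.

```python
# Pick the data source by name, then load a path through it.
df = spark.read.format("json").load("python/test_support/sql/people.json")
df.printSchema()  # expect `age` (bigint) and `name` (string)
```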
@@ -87,7 +87,7 @@ class DataFrameReader(object):
"""
if not isinstance(schema, StructType):
raise TypeError("schema should be StructType")
- jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
+ jschema = self._spark._ssql_ctx.parseDataType(schema.json())
self._jreader = self._jreader.schema(jschema)
return self
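The `schema()` call converts a Python `StructType` into its JVM counterpart before attaching it to the reader. A hedged sketch of supplying an explicit schema; the field names mirror the sample JSON used in the surrounding doctests.

```python
from pyspark.sql.types import StructType, StructField, LongType, StringType

# An explicit schema skips inference; assumes `spark` is an active SparkSession.
people_schema = StructType([
    StructField("age", LongType(), True),
    StructField("name", StringType(), True),
])
df = spark.read.schema(people_schema).json("python/test_support/sql/people.json")
```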
@@ -115,12 +115,12 @@ class DataFrameReader(object):
:param schema: optional :class:`StructType` for the input schema.
:param options: all other string options
- >>> df = sqlContext.read.load('python/test_support/sql/parquet_partitioned', opt1=True,
+ >>> df = spark.read.load('python/test_support/sql/parquet_partitioned', opt1=True,
... opt2=1, opt3='str')
>>> df.dtypes
[('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
- >>> df = sqlContext.read.format('json').load(['python/test_support/sql/people.json',
+ >>> df = spark.read.format('json').load(['python/test_support/sql/people.json',
... 'python/test_support/sql/people1.json'])
>>> df.dtypes
[('age', 'bigint'), ('aka', 'string'), ('name', 'string')]
@@ -133,7 +133,7 @@ class DataFrameReader(object):
if path is not None:
if type(path) != list:
path = [path]
- return self._df(self._jreader.load(self._sqlContext._sc._jvm.PythonUtils.toSeq(path)))
+ return self._df(self._jreader.load(self._spark._sc._jvm.PythonUtils.toSeq(path)))
else:
return self._df(self._jreader.load())
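As the list handling above suggests, `load()` accepts either a single path or a list of paths; a minimal sketch assuming an active `spark` session and the test files from the doctests.

```python
# Generic load with an explicit format name.
df = spark.read.load("python/test_support/sql/parquet_partitioned", format="parquet")

# A list of paths is wrapped into a JVM Seq, matching the branch above.
df2 = spark.read.load(
    ["python/test_support/sql/people.json",
     "python/test_support/sql/people1.json"],
    format="json",
)
```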
@@ -148,7 +148,7 @@ class DataFrameReader(object):
:param schema: optional :class:`StructType` for the input schema.
:param options: all other string options
- >>> df = sqlContext.read.format('text').stream('python/test_support/sql/streaming')
+ >>> df = spark.read.format('text').stream('python/test_support/sql/streaming')
>>> df.isStreaming
True
"""
@@ -211,11 +211,11 @@ class DataFrameReader(object):
``spark.sql.columnNameOfCorruptRecord``. If None is set,
it uses the default value ``_corrupt_record``.
- >>> df1 = sqlContext.read.json('python/test_support/sql/people.json')
+ >>> df1 = spark.read.json('python/test_support/sql/people.json')
>>> df1.dtypes
[('age', 'bigint'), ('name', 'string')]
>>> rdd = sc.textFile('python/test_support/sql/people.json')
- >>> df2 = sqlContext.read.json(rdd)
+ >>> df2 = spark.read.json(rdd)
>>> df2.dtypes
[('age', 'bigint'), ('name', 'string')]
@@ -243,7 +243,7 @@ class DataFrameReader(object):
if isinstance(path, basestring):
path = [path]
if type(path) == list:
- return self._df(self._jreader.json(self._sqlContext._sc._jvm.PythonUtils.toSeq(path)))
+ return self._df(self._jreader.json(self._spark._sc._jvm.PythonUtils.toSeq(path)))
elif isinstance(path, RDD):
def func(iterator):
for x in iterator:
@@ -254,7 +254,7 @@ class DataFrameReader(object):
yield x
keyed = path.mapPartitions(func)
keyed._bypass_serializer = True
- jrdd = keyed._jrdd.map(self._sqlContext._jvm.BytesToString())
+ jrdd = keyed._jrdd.map(self._spark._jvm.BytesToString())
return self._df(self._jreader.json(jrdd))
else:
raise TypeError("path can be only string or RDD")
@@ -265,9 +265,9 @@ class DataFrameReader(object):
:param tableName: string, name of the table.
- >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned')
+ >>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned')
>>> df.registerTempTable('tmpTable')
- >>> sqlContext.read.table('tmpTable').dtypes
+ >>> spark.read.table('tmpTable').dtypes
[('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
"""
return self._df(self._jreader.table(tableName))
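A sketch of the round trip exercised by the doctest above: register a DataFrame as a temporary table and read it back by name (the table name is only illustrative).

```python
df = spark.read.parquet("python/test_support/sql/parquet_partitioned")
df.registerTempTable("tmpTable")          # temp table scoped to the session
same_df = spark.read.table("tmpTable")    # resolves the name via the catalog
```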
@@ -276,11 +276,11 @@ class DataFrameReader(object):
def parquet(self, *paths):
"""Loads a Parquet file, returning the result as a :class:`DataFrame`.
- >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned')
+ >>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned')
>>> df.dtypes
[('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
"""
- return self._df(self._jreader.parquet(_to_seq(self._sqlContext._sc, paths)))
+ return self._df(self._jreader.parquet(_to_seq(self._spark._sc, paths)))
@ignore_unicode_prefix
@since(1.6)
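The parquet doctest relies on partition discovery; a brief sketch, assuming `spark` is an active SparkSession and the partitioned test directory from the doctest.

```python
# year/month/day are recovered from the directory layout as partition columns.
df = spark.read.parquet("python/test_support/sql/parquet_partitioned")
print(df.dtypes)  # [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
```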
@@ -291,13 +291,13 @@ class DataFrameReader(object):
:param paths: string, or list of strings, for input path(s).
- >>> df = sqlContext.read.text('python/test_support/sql/text-test.txt')
+ >>> df = spark.read.text('python/test_support/sql/text-test.txt')
>>> df.collect()
[Row(value=u'hello'), Row(value=u'this')]
"""
if isinstance(paths, basestring):
path = [paths]
- return self._df(self._jreader.text(self._sqlContext._sc._jvm.PythonUtils.toSeq(path)))
+ return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(path)))
@since(2.0)
def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None,
@@ -356,7 +356,7 @@ class DataFrameReader(object):
* ``DROPMALFORMED`` : ignores the whole corrupted records.
* ``FAILFAST`` : throws an exception when it meets corrupted records.
- >>> df = sqlContext.read.csv('python/test_support/sql/ages.csv')
+ >>> df = spark.read.csv('python/test_support/sql/ages.csv')
>>> df.dtypes
[('C0', 'string'), ('C1', 'string')]
"""
@@ -396,7 +396,7 @@ class DataFrameReader(object):
self.option("mode", mode)
if isinstance(path, basestring):
path = [path]
- return self._df(self._jreader.csv(self._sqlContext._sc._jvm.PythonUtils.toSeq(path)))
+ return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path)))
@since(1.5)
def orc(self, path):
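For the CSV reader above, options are plain keyword arguments; the `header` and `inferSchema` flags in this sketch are illustrative additions, not taken from the doctest. Assumes `spark` is an active SparkSession.

```python
df = spark.read.csv(
    "python/test_support/sql/ages.csv",
    sep=",",
    header=False,          # the sample file has no header row (columns C0, C1)
    inferSchema=True,      # illustrative; the doctest keeps everything as strings
    mode="DROPMALFORMED",  # drop corrupted records, as documented above
)
```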
@@ -441,16 +441,16 @@ class DataFrameReader(object):
"""
if properties is None:
properties = dict()
- jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
+ jprop = JavaClass("java.util.Properties", self._spark._sc._gateway._gateway_client)()
for k in properties:
jprop.setProperty(k, properties[k])
if column is not None:
if numPartitions is None:
- numPartitions = self._sqlContext._sc.defaultParallelism
+ numPartitions = self._spark._sc.defaultParallelism
return self._df(self._jreader.jdbc(url, table, column, int(lowerBound), int(upperBound),
int(numPartitions), jprop))
if predicates is not None:
- gateway = self._sqlContext._sc._gateway
+ gateway = self._spark._sc._gateway
jpredicates = utils.toJArray(gateway, gateway.jvm.java.lang.String, predicates)
return self._df(self._jreader.jdbc(url, table, jpredicates, jprop))
return self._df(self._jreader.jdbc(url, table, jprop))
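A hedged sketch of the partitioned JDBC read handled above; the URL, table name, and credentials are placeholders, and a matching JDBC driver must be on the classpath.

```python
# Rows are split into numPartitions ranges over `column`; when numPartitions
# is omitted it falls back to sc.defaultParallelism, as in the code above.
df = spark.read.jdbc(
    url="jdbc:postgresql://db-host:5432/mydb",   # placeholder URL
    table="people",                              # placeholder table
    column="id",
    lowerBound=1,
    upperBound=100000,
    numPartitions=4,
    properties={"user": "reader", "password": "secret"},  # placeholder credentials
)
```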
@@ -466,7 +466,7 @@ class DataFrameWriter(object):
"""
def __init__(self, df):
self._df = df
- self._sqlContext = df.sql_ctx
+ self._spark = df.sql_ctx
self._jwrite = df._jdf.write()
def _cq(self, jcq):
@@ -531,14 +531,14 @@ class DataFrameWriter(object):
"""
if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
cols = cols[0]
- self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
+ self._jwrite = self._jwrite.partitionBy(_to_seq(self._spark._sc, cols))
return self
@since(2.0)
def queryName(self, queryName):
"""Specifies the name of the :class:`ContinuousQuery` that can be started with
:func:`startStream`. This name must be unique among all the currently active queries
- in the associated SQLContext
+ in the associated spark
.. note:: Experimental.
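A short sketch of `partitionBy` on the write side; the output path is illustrative, and `df` is assumed to carry `year` and `month` columns (as the partitioned test data above does).

```python
# One output subdirectory per distinct (year, month) value.
df.write.partitionBy("year", "month").parquet("/tmp/parquet_by_month")
```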
@@ -573,7 +573,7 @@ class DataFrameWriter(object):
trigger = ProcessingTime(processingTime)
if trigger is None:
raise ValueError('A trigger was not provided. Supported triggers: processingTime.')
- self._jwrite = self._jwrite.trigger(trigger._to_java_trigger(self._sqlContext))
+ self._jwrite = self._jwrite.trigger(trigger._to_java_trigger(self._spark))
return self
@since(1.4)
@@ -854,7 +854,7 @@ class DataFrameWriter(object):
"""
if properties is None:
properties = dict()
- jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
+ jprop = JavaClass("java.util.Properties", self._spark._sc._gateway._gateway_client)()
for k in properties:
jprop.setProperty(k, properties[k])
self._jwrite.mode(mode).jdbc(url, table, jprop)
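The mirror image on the write side, again with placeholder URL, table, and credentials; the `properties` dict becomes a `java.util.Properties` on the JVM side exactly as in the hunk above.

```python
df.write.jdbc(
    url="jdbc:postgresql://db-host:5432/mydb",   # placeholder URL
    table="people_copy",                         # placeholder table
    mode="append",                               # append rather than overwrite
    properties={"user": "writer", "password": "secret"},  # placeholder credentials
)
```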
@@ -865,7 +865,7 @@ def _test():
import os
import tempfile
from pyspark.context import SparkContext
- from pyspark.sql import Row, SQLContext, HiveContext
+ from pyspark.sql import SparkSession, Row, HiveContext
import pyspark.sql.readwriter
os.chdir(os.environ["SPARK_HOME"])
@@ -876,11 +876,13 @@ def _test():
globs['tempfile'] = tempfile
globs['os'] = os
globs['sc'] = sc
- globs['sqlContext'] = SQLContext(sc)
+ globs['spark'] = SparkSession.builder\
+ .enableHiveSupport()\
+ .getOrCreate()
globs['hiveContext'] = HiveContext._createForTesting(sc)
- globs['df'] = globs['sqlContext'].read.parquet('python/test_support/sql/parquet_partitioned')
+ globs['df'] = globs['spark'].read.parquet('python/test_support/sql/parquet_partitioned')
globs['sdf'] = \
- globs['sqlContext'].read.format('text').stream('python/test_support/sql/streaming')
+ globs['spark'].read.format('text').stream('python/test_support/sql/streaming')
(failure_count, test_count) = doctest.testmod(
pyspark.sql.readwriter, globs=globs,