author     Sandeep Singh <sandeep@techaddict.me>    2016-05-11 11:24:16 -0700
committer  Davies Liu <davies.liu@gmail.com>        2016-05-11 11:24:16 -0700
commit     29314379729de4082bd2297c9e5289e3e4a0115e (patch)
tree       a5aede7207fde856910581f7f97f4b65b73a6e39 /python
parent     d8935db5ecb7c959585411da9bf1e9a9c4d5cb37 (diff)
[SPARK-15037] [SQL] [MLLIB] Part2: Use SparkSession instead of SQLContext in Python TestSuites
## What changes were proposed in this pull request?

Use SparkSession instead of SQLContext in Python TestSuites

## How was this patch tested?

Existing tests

Author: Sandeep Singh <sandeep@techaddict.me>

Closes #13044 from techaddict/SPARK-15037-python.
Diffstat (limited to 'python')
-rwxr-xr-x  python/pyspark/ml/tests.py        97
-rw-r--r--  python/pyspark/mllib/tests.py     19
-rw-r--r--  python/pyspark/sql/readwriter.py  72
-rw-r--r--  python/pyspark/sql/tests.py      379
4 files changed, 273 insertions, 294 deletions
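
Most of the churn in pyspark/ml/tests.py below comes from one new base class, SparkSessionTestCase, which wraps the reused SparkContext in a SparkSession so each test can call self.spark.createDataFrame(...) instead of constructing a throwaway SQLContext. A minimal sketch of that pattern, assuming the ReusedPySparkTestCase helper from Spark's own test suite; the ExampleFeatureTest class is illustrative, not part of the patch:

```python
from pyspark.sql import SparkSession
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase


class SparkSessionTestCase(PySparkTestCase):
    """Shared base class: one SparkSession built on top of the reused SparkContext."""

    @classmethod
    def setUpClass(cls):
        PySparkTestCase.setUpClass()       # creates cls.sc
        cls.spark = SparkSession(cls.sc)   # reuse it for a session

    @classmethod
    def tearDownClass(cls):
        PySparkTestCase.tearDownClass()    # stops the shared SparkContext
        cls.spark.stop()


class ExampleFeatureTest(SparkSessionTestCase):
    def test_create_dataframe(self):
        # Tests build DataFrames through the session rather than SQLContext(self.sc).
        df = self.spark.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])
        self.assertEqual(df.count(), 2)
```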
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index ad1631fb5b..49d3a4a332 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -57,13 +57,25 @@ from pyspark.ml.tuning import *
from pyspark.ml.wrapper import JavaParams
from pyspark.mllib.common import _java2py
from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector
-from pyspark.sql import DataFrame, SQLContext, Row
+from pyspark.sql import DataFrame, Row, SparkSession
from pyspark.sql.functions import rand
from pyspark.sql.utils import IllegalArgumentException
from pyspark.storagelevel import *
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
+class SparkSessionTestCase(PySparkTestCase):
+ @classmethod
+ def setUpClass(cls):
+ PySparkTestCase.setUpClass()
+ cls.spark = SparkSession(cls.sc)
+
+ @classmethod
+ def tearDownClass(cls):
+ PySparkTestCase.tearDownClass()
+ cls.spark.stop()
+
+
class MockDataset(DataFrame):
def __init__(self):
@@ -350,7 +362,7 @@ class ParamTests(PySparkTestCase):
self.assertEqual(model.getWindowSize(), 6)
-class FeatureTests(PySparkTestCase):
+class FeatureTests(SparkSessionTestCase):
def test_binarizer(self):
b0 = Binarizer()
@@ -376,8 +388,7 @@ class FeatureTests(PySparkTestCase):
self.assertEqual(b1.getOutputCol(), "output")
def test_idf(self):
- sqlContext = SQLContext(self.sc)
- dataset = sqlContext.createDataFrame([
+ dataset = self.spark.createDataFrame([
(DenseVector([1.0, 2.0]),),
(DenseVector([0.0, 1.0]),),
(DenseVector([3.0, 0.2]),)], ["tf"])
@@ -390,8 +401,7 @@ class FeatureTests(PySparkTestCase):
self.assertIsNotNone(output.head().idf)
def test_ngram(self):
- sqlContext = SQLContext(self.sc)
- dataset = sqlContext.createDataFrame([
+ dataset = self.spark.createDataFrame([
Row(input=["a", "b", "c", "d", "e"])])
ngram0 = NGram(n=4, inputCol="input", outputCol="output")
self.assertEqual(ngram0.getN(), 4)
@@ -401,8 +411,7 @@ class FeatureTests(PySparkTestCase):
self.assertEqual(transformedDF.head().output, ["a b c d", "b c d e"])
def test_stopwordsremover(self):
- sqlContext = SQLContext(self.sc)
- dataset = sqlContext.createDataFrame([Row(input=["a", "panda"])])
+ dataset = self.spark.createDataFrame([Row(input=["a", "panda"])])
stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output")
# Default
self.assertEqual(stopWordRemover.getInputCol(), "input")
@@ -419,15 +428,14 @@ class FeatureTests(PySparkTestCase):
self.assertEqual(transformedDF.head().output, ["a"])
# with language selection
stopwords = StopWordsRemover.loadDefaultStopWords("turkish")
- dataset = sqlContext.createDataFrame([Row(input=["acaba", "ama", "biri"])])
+ dataset = self.spark.createDataFrame([Row(input=["acaba", "ama", "biri"])])
stopWordRemover.setStopWords(stopwords)
self.assertEqual(stopWordRemover.getStopWords(), stopwords)
transformedDF = stopWordRemover.transform(dataset)
self.assertEqual(transformedDF.head().output, [])
def test_count_vectorizer_with_binary(self):
- sqlContext = SQLContext(self.sc)
- dataset = sqlContext.createDataFrame([
+ dataset = self.spark.createDataFrame([
(0, "a a a b b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),),
(1, "a a".split(' '), SparseVector(3, {0: 1.0}),),
(2, "a b".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),),
@@ -475,11 +483,10 @@ class InducedErrorEstimator(Estimator, HasInducedError):
return model
-class CrossValidatorTests(PySparkTestCase):
+class CrossValidatorTests(SparkSessionTestCase):
def test_copy(self):
- sqlContext = SQLContext(self.sc)
- dataset = sqlContext.createDataFrame([
+ dataset = self.spark.createDataFrame([
(10, 10.0),
(50, 50.0),
(100, 100.0),
@@ -503,8 +510,7 @@ class CrossValidatorTests(PySparkTestCase):
< 0.0001)
def test_fit_minimize_metric(self):
- sqlContext = SQLContext(self.sc)
- dataset = sqlContext.createDataFrame([
+ dataset = self.spark.createDataFrame([
(10, 10.0),
(50, 50.0),
(100, 100.0),
@@ -527,8 +533,7 @@ class CrossValidatorTests(PySparkTestCase):
self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0")
def test_fit_maximize_metric(self):
- sqlContext = SQLContext(self.sc)
- dataset = sqlContext.createDataFrame([
+ dataset = self.spark.createDataFrame([
(10, 10.0),
(50, 50.0),
(100, 100.0),
@@ -554,8 +559,7 @@ class CrossValidatorTests(PySparkTestCase):
# This tests saving and loading the trained model only.
# Save/load for CrossValidator will be added later: SPARK-13786
temp_path = tempfile.mkdtemp()
- sqlContext = SQLContext(self.sc)
- dataset = sqlContext.createDataFrame(
+ dataset = self.spark.createDataFrame(
[(Vectors.dense([0.0]), 0.0),
(Vectors.dense([0.4]), 1.0),
(Vectors.dense([0.5]), 0.0),
@@ -576,11 +580,10 @@ class CrossValidatorTests(PySparkTestCase):
self.assertEqual(loadedLrModel.intercept, lrModel.intercept)
-class TrainValidationSplitTests(PySparkTestCase):
+class TrainValidationSplitTests(SparkSessionTestCase):
def test_fit_minimize_metric(self):
- sqlContext = SQLContext(self.sc)
- dataset = sqlContext.createDataFrame([
+ dataset = self.spark.createDataFrame([
(10, 10.0),
(50, 50.0),
(100, 100.0),
@@ -603,8 +606,7 @@ class TrainValidationSplitTests(PySparkTestCase):
self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0")
def test_fit_maximize_metric(self):
- sqlContext = SQLContext(self.sc)
- dataset = sqlContext.createDataFrame([
+ dataset = self.spark.createDataFrame([
(10, 10.0),
(50, 50.0),
(100, 100.0),
@@ -630,8 +632,7 @@ class TrainValidationSplitTests(PySparkTestCase):
# This tests saving and loading the trained model only.
# Save/load for TrainValidationSplit will be added later: SPARK-13786
temp_path = tempfile.mkdtemp()
- sqlContext = SQLContext(self.sc)
- dataset = sqlContext.createDataFrame(
+ dataset = self.spark.createDataFrame(
[(Vectors.dense([0.0]), 0.0),
(Vectors.dense([0.4]), 1.0),
(Vectors.dense([0.5]), 0.0),
@@ -652,7 +653,7 @@ class TrainValidationSplitTests(PySparkTestCase):
self.assertEqual(loadedLrModel.intercept, lrModel.intercept)
-class PersistenceTest(PySparkTestCase):
+class PersistenceTest(SparkSessionTestCase):
def test_linear_regression(self):
lr = LinearRegression(maxIter=1)
@@ -724,11 +725,10 @@ class PersistenceTest(PySparkTestCase):
"""
Pipeline[HashingTF, PCA]
"""
- sqlContext = SQLContext(self.sc)
temp_path = tempfile.mkdtemp()
try:
- df = sqlContext.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"])
+ df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"])
tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
pca = PCA(k=2, inputCol="features", outputCol="pca_features")
pl = Pipeline(stages=[tf, pca])
@@ -753,11 +753,10 @@ class PersistenceTest(PySparkTestCase):
"""
Pipeline[HashingTF, Pipeline[PCA]]
"""
- sqlContext = SQLContext(self.sc)
temp_path = tempfile.mkdtemp()
try:
- df = sqlContext.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"])
+ df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"])
tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
pca = PCA(k=2, inputCol="features", outputCol="pca_features")
p0 = Pipeline(stages=[pca])
@@ -816,7 +815,7 @@ class PersistenceTest(PySparkTestCase):
pass
-class LDATest(PySparkTestCase):
+class LDATest(SparkSessionTestCase):
def _compare(self, m1, m2):
"""
@@ -836,8 +835,7 @@ class LDATest(PySparkTestCase):
def test_persistence(self):
# Test save/load for LDA, LocalLDAModel, DistributedLDAModel.
- sqlContext = SQLContext(self.sc)
- df = sqlContext.createDataFrame([
+ df = self.spark.createDataFrame([
[1, Vectors.dense([0.0, 1.0])],
[2, Vectors.sparse(2, {0: 1.0})],
], ["id", "features"])
@@ -871,12 +869,11 @@ class LDATest(PySparkTestCase):
pass
-class TrainingSummaryTest(PySparkTestCase):
+class TrainingSummaryTest(SparkSessionTestCase):
def test_linear_regression_summary(self):
from pyspark.mllib.linalg import Vectors
- sqlContext = SQLContext(self.sc)
- df = sqlContext.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
+ df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
(0.0, 2.0, Vectors.sparse(1, [], []))],
["label", "weight", "features"])
lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight",
@@ -914,8 +911,7 @@ class TrainingSummaryTest(PySparkTestCase):
def test_logistic_regression_summary(self):
from pyspark.mllib.linalg import Vectors
- sqlContext = SQLContext(self.sc)
- df = sqlContext.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
+ df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
(0.0, 2.0, Vectors.sparse(1, [], []))],
["label", "weight", "features"])
lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False)
@@ -942,11 +938,10 @@ class TrainingSummaryTest(PySparkTestCase):
self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)
-class OneVsRestTests(PySparkTestCase):
+class OneVsRestTests(SparkSessionTestCase):
def test_copy(self):
- sqlContext = SQLContext(self.sc)
- df = sqlContext.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
+ df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
(1.0, Vectors.sparse(2, [], [])),
(2.0, Vectors.dense(0.5, 0.5))],
["label", "features"])
@@ -960,8 +955,7 @@ class OneVsRestTests(PySparkTestCase):
self.assertEqual(model1.getPredictionCol(), "indexed")
def test_output_columns(self):
- sqlContext = SQLContext(self.sc)
- df = sqlContext.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
+ df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
(1.0, Vectors.sparse(2, [], [])),
(2.0, Vectors.dense(0.5, 0.5))],
["label", "features"])
@@ -973,8 +967,7 @@ class OneVsRestTests(PySparkTestCase):
def test_save_load(self):
temp_path = tempfile.mkdtemp()
- sqlContext = SQLContext(self.sc)
- df = sqlContext.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
+ df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
(1.0, Vectors.sparse(2, [], [])),
(2.0, Vectors.dense(0.5, 0.5))],
["label", "features"])
@@ -994,12 +987,11 @@ class OneVsRestTests(PySparkTestCase):
self.assertEqual(m.uid, n.uid)
-class HashingTFTest(PySparkTestCase):
+class HashingTFTest(SparkSessionTestCase):
def test_apply_binary_term_freqs(self):
- sqlContext = SQLContext(self.sc)
- df = sqlContext.createDataFrame([(0, ["a", "a", "b", "c", "c", "c"])], ["id", "words"])
+ df = self.spark.createDataFrame([(0, ["a", "a", "b", "c", "c", "c"])], ["id", "words"])
n = 10
hashingTF = HashingTF()
hashingTF.setInputCol("words").setOutputCol("features").setNumFeatures(n).setBinary(True)
@@ -1011,11 +1003,10 @@ class HashingTFTest(PySparkTestCase):
": expected " + str(expected[i]) + ", got " + str(features[i]))
-class ALSTest(PySparkTestCase):
+class ALSTest(SparkSessionTestCase):
def test_storage_levels(self):
- sqlContext = SQLContext(self.sc)
- df = sqlContext.createDataFrame(
+ df = self.spark.createDataFrame(
[(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
["user", "item", "rating"])
als = ALS().setMaxIter(1).setRank(1)
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 53a1d2c59c..74cf7bb8ea 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -66,7 +66,8 @@ from pyspark.mllib.util import LinearDataGenerator
from pyspark.mllib.util import MLUtils
from pyspark.serializers import PickleSerializer
from pyspark.streaming import StreamingContext
-from pyspark.sql import SQLContext
+from pyspark.sql import SparkSession
+from pyspark.sql.utils import IllegalArgumentException
from pyspark.streaming import StreamingContext
_have_scipy = False
@@ -83,9 +84,10 @@ ser = PickleSerializer()
class MLlibTestCase(unittest.TestCase):
def setUp(self):
self.sc = SparkContext('local[4]', "MLlib tests")
+ self.spark = SparkSession(self.sc)
def tearDown(self):
- self.sc.stop()
+ self.spark.stop()
class MLLibStreamingTestCase(unittest.TestCase):
@@ -698,7 +700,6 @@ class VectorUDTTests(MLlibTestCase):
self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v)))
def test_infer_schema(self):
- sqlCtx = SQLContext(self.sc)
rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)])
df = rdd.toDF()
schema = df.schema
@@ -731,7 +732,6 @@ class MatrixUDTTests(MLlibTestCase):
self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m)))
def test_infer_schema(self):
- sqlCtx = SQLContext(self.sc)
rdd = self.sc.parallelize([("dense", self.dm1), ("sparse", self.sm1)])
df = rdd.toDF()
schema = df.schema
@@ -919,7 +919,7 @@ class ChiSqTestTests(MLlibTestCase):
# Negative counts in observed
neg_obs = Vectors.dense([1.0, 2.0, 3.0, -4.0])
- self.assertRaises(Py4JJavaError, Statistics.chiSqTest, neg_obs, expected1)
+ self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, neg_obs, expected1)
# Count = 0.0 in expected but not observed
zero_expected = Vectors.dense([1.0, 0.0, 3.0])
@@ -930,7 +930,8 @@ class ChiSqTestTests(MLlibTestCase):
# 0.0 in expected and observed simultaneously
zero_observed = Vectors.dense([2.0, 0.0, 1.0])
- self.assertRaises(Py4JJavaError, Statistics.chiSqTest, zero_observed, zero_expected)
+ self.assertRaises(
+ IllegalArgumentException, Statistics.chiSqTest, zero_observed, zero_expected)
def test_matrix_independence(self):
data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0]
@@ -944,15 +945,15 @@ class ChiSqTestTests(MLlibTestCase):
# Negative counts
neg_counts = Matrices.dense(2, 2, [4.0, 5.0, 3.0, -3.0])
- self.assertRaises(Py4JJavaError, Statistics.chiSqTest, neg_counts)
+ self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, neg_counts)
# Row sum = 0.0
row_zero = Matrices.dense(2, 2, [0.0, 1.0, 0.0, 2.0])
- self.assertRaises(Py4JJavaError, Statistics.chiSqTest, row_zero)
+ self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, row_zero)
# Column sum = 0.0
col_zero = Matrices.dense(2, 2, [0.0, 0.0, 2.0, 2.0])
- self.assertRaises(Py4JJavaError, Statistics.chiSqTest, col_zero)
+ self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, col_zero)
def test_chi_sq_pearson(self):
data = [
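
Besides the SparkSession setup, the mllib/tests.py hunks above change the expected exception for invalid chi-square inputs from Py4JJavaError to pyspark.sql.utils.IllegalArgumentException, which is why that import is added at the top of the file. A minimal sketch of the assertion pattern with illustrative vectors; the unittest wrapper and local SparkContext are assumptions for the sketch, not part of the patch:

```python
import unittest

from pyspark import SparkContext
from pyspark.mllib.linalg import Matrices, Vectors
from pyspark.mllib.stat import Statistics
from pyspark.sql.utils import IllegalArgumentException


class ChiSqErrorSketch(unittest.TestCase):
    def setUp(self):
        self.sc = SparkContext("local[2]", "chisq-sketch")

    def tearDown(self):
        self.sc.stop()

    def test_invalid_inputs(self):
        # Negative observed counts are rejected on the JVM side; the error now
        # surfaces as IllegalArgumentException instead of a raw Py4JJavaError.
        neg_obs = Vectors.dense([1.0, 2.0, 3.0, -4.0])
        expected = Vectors.dense([1.5, 2.5, 3.5, 4.5])
        self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, neg_obs, expected)

        # The same applies to contingency matrices whose row or column sums are zero.
        row_zero = Matrices.dense(2, 2, [0.0, 1.0, 0.0, 2.0])
        self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, row_zero)
```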
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 7e79df33e8..bd728c97c8 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -47,19 +47,19 @@ def to_str(value):
class DataFrameReader(object):
"""
Interface used to load a :class:`DataFrame` from external storage systems
- (e.g. file systems, key-value stores, etc). Use :func:`SQLContext.read`
+ (e.g. file systems, key-value stores, etc). Use :func:`spark.read`
to access this.
.. versionadded:: 1.4
"""
- def __init__(self, sqlContext):
- self._jreader = sqlContext._ssql_ctx.read()
- self._sqlContext = sqlContext
+ def __init__(self, spark):
+ self._jreader = spark._ssql_ctx.read()
+ self._spark = spark
def _df(self, jdf):
from pyspark.sql.dataframe import DataFrame
- return DataFrame(jdf, self._sqlContext)
+ return DataFrame(jdf, self._spark)
@since(1.4)
def format(self, source):
@@ -67,7 +67,7 @@ class DataFrameReader(object):
:param source: string, name of the data source, e.g. 'json', 'parquet'.
- >>> df = sqlContext.read.format('json').load('python/test_support/sql/people.json')
+ >>> df = spark.read.format('json').load('python/test_support/sql/people.json')
>>> df.dtypes
[('age', 'bigint'), ('name', 'string')]
@@ -87,7 +87,7 @@ class DataFrameReader(object):
"""
if not isinstance(schema, StructType):
raise TypeError("schema should be StructType")
- jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
+ jschema = self._spark._ssql_ctx.parseDataType(schema.json())
self._jreader = self._jreader.schema(jschema)
return self
@@ -115,12 +115,12 @@ class DataFrameReader(object):
:param schema: optional :class:`StructType` for the input schema.
:param options: all other string options
- >>> df = sqlContext.read.load('python/test_support/sql/parquet_partitioned', opt1=True,
+ >>> df = spark.read.load('python/test_support/sql/parquet_partitioned', opt1=True,
... opt2=1, opt3='str')
>>> df.dtypes
[('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
- >>> df = sqlContext.read.format('json').load(['python/test_support/sql/people.json',
+ >>> df = spark.read.format('json').load(['python/test_support/sql/people.json',
... 'python/test_support/sql/people1.json'])
>>> df.dtypes
[('age', 'bigint'), ('aka', 'string'), ('name', 'string')]
@@ -133,7 +133,7 @@ class DataFrameReader(object):
if path is not None:
if type(path) != list:
path = [path]
- return self._df(self._jreader.load(self._sqlContext._sc._jvm.PythonUtils.toSeq(path)))
+ return self._df(self._jreader.load(self._spark._sc._jvm.PythonUtils.toSeq(path)))
else:
return self._df(self._jreader.load())
@@ -148,7 +148,7 @@ class DataFrameReader(object):
:param schema: optional :class:`StructType` for the input schema.
:param options: all other string options
- >>> df = sqlContext.read.format('text').stream('python/test_support/sql/streaming')
+ >>> df = spark.read.format('text').stream('python/test_support/sql/streaming')
>>> df.isStreaming
True
"""
@@ -211,11 +211,11 @@ class DataFrameReader(object):
``spark.sql.columnNameOfCorruptRecord``. If None is set,
it uses the default value ``_corrupt_record``.
- >>> df1 = sqlContext.read.json('python/test_support/sql/people.json')
+ >>> df1 = spark.read.json('python/test_support/sql/people.json')
>>> df1.dtypes
[('age', 'bigint'), ('name', 'string')]
>>> rdd = sc.textFile('python/test_support/sql/people.json')
- >>> df2 = sqlContext.read.json(rdd)
+ >>> df2 = spark.read.json(rdd)
>>> df2.dtypes
[('age', 'bigint'), ('name', 'string')]
@@ -243,7 +243,7 @@ class DataFrameReader(object):
if isinstance(path, basestring):
path = [path]
if type(path) == list:
- return self._df(self._jreader.json(self._sqlContext._sc._jvm.PythonUtils.toSeq(path)))
+ return self._df(self._jreader.json(self._spark._sc._jvm.PythonUtils.toSeq(path)))
elif isinstance(path, RDD):
def func(iterator):
for x in iterator:
@@ -254,7 +254,7 @@ class DataFrameReader(object):
yield x
keyed = path.mapPartitions(func)
keyed._bypass_serializer = True
- jrdd = keyed._jrdd.map(self._sqlContext._jvm.BytesToString())
+ jrdd = keyed._jrdd.map(self._spark._jvm.BytesToString())
return self._df(self._jreader.json(jrdd))
else:
raise TypeError("path can be only string or RDD")
@@ -265,9 +265,9 @@ class DataFrameReader(object):
:param tableName: string, name of the table.
- >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned')
+ >>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned')
>>> df.registerTempTable('tmpTable')
- >>> sqlContext.read.table('tmpTable').dtypes
+ >>> spark.read.table('tmpTable').dtypes
[('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
"""
return self._df(self._jreader.table(tableName))
@@ -276,11 +276,11 @@ class DataFrameReader(object):
def parquet(self, *paths):
"""Loads a Parquet file, returning the result as a :class:`DataFrame`.
- >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned')
+ >>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned')
>>> df.dtypes
[('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
"""
- return self._df(self._jreader.parquet(_to_seq(self._sqlContext._sc, paths)))
+ return self._df(self._jreader.parquet(_to_seq(self._spark._sc, paths)))
@ignore_unicode_prefix
@since(1.6)
@@ -291,13 +291,13 @@ class DataFrameReader(object):
:param paths: string, or list of strings, for input path(s).
- >>> df = sqlContext.read.text('python/test_support/sql/text-test.txt')
+ >>> df = spark.read.text('python/test_support/sql/text-test.txt')
>>> df.collect()
[Row(value=u'hello'), Row(value=u'this')]
"""
if isinstance(paths, basestring):
path = [paths]
- return self._df(self._jreader.text(self._sqlContext._sc._jvm.PythonUtils.toSeq(path)))
+ return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(path)))
@since(2.0)
def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None,
@@ -356,7 +356,7 @@ class DataFrameReader(object):
* ``DROPMALFORMED`` : ignores the whole corrupted records.
* ``FAILFAST`` : throws an exception when it meets corrupted records.
- >>> df = sqlContext.read.csv('python/test_support/sql/ages.csv')
+ >>> df = spark.read.csv('python/test_support/sql/ages.csv')
>>> df.dtypes
[('C0', 'string'), ('C1', 'string')]
"""
@@ -396,7 +396,7 @@ class DataFrameReader(object):
self.option("mode", mode)
if isinstance(path, basestring):
path = [path]
- return self._df(self._jreader.csv(self._sqlContext._sc._jvm.PythonUtils.toSeq(path)))
+ return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path)))
@since(1.5)
def orc(self, path):
@@ -441,16 +441,16 @@ class DataFrameReader(object):
"""
if properties is None:
properties = dict()
- jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
+ jprop = JavaClass("java.util.Properties", self._spark._sc._gateway._gateway_client)()
for k in properties:
jprop.setProperty(k, properties[k])
if column is not None:
if numPartitions is None:
- numPartitions = self._sqlContext._sc.defaultParallelism
+ numPartitions = self._spark._sc.defaultParallelism
return self._df(self._jreader.jdbc(url, table, column, int(lowerBound), int(upperBound),
int(numPartitions), jprop))
if predicates is not None:
- gateway = self._sqlContext._sc._gateway
+ gateway = self._spark._sc._gateway
jpredicates = utils.toJArray(gateway, gateway.jvm.java.lang.String, predicates)
return self._df(self._jreader.jdbc(url, table, jpredicates, jprop))
return self._df(self._jreader.jdbc(url, table, jprop))
@@ -466,7 +466,7 @@ class DataFrameWriter(object):
"""
def __init__(self, df):
self._df = df
- self._sqlContext = df.sql_ctx
+ self._spark = df.sql_ctx
self._jwrite = df._jdf.write()
def _cq(self, jcq):
@@ -531,14 +531,14 @@ class DataFrameWriter(object):
"""
if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
cols = cols[0]
- self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
+ self._jwrite = self._jwrite.partitionBy(_to_seq(self._spark._sc, cols))
return self
@since(2.0)
def queryName(self, queryName):
"""Specifies the name of the :class:`ContinuousQuery` that can be started with
:func:`startStream`. This name must be unique among all the currently active queries
- in the associated SQLContext
+ in the associated spark
.. note:: Experimental.
@@ -573,7 +573,7 @@ class DataFrameWriter(object):
trigger = ProcessingTime(processingTime)
if trigger is None:
raise ValueError('A trigger was not provided. Supported triggers: processingTime.')
- self._jwrite = self._jwrite.trigger(trigger._to_java_trigger(self._sqlContext))
+ self._jwrite = self._jwrite.trigger(trigger._to_java_trigger(self._spark))
return self
@since(1.4)
@@ -854,7 +854,7 @@ class DataFrameWriter(object):
"""
if properties is None:
properties = dict()
- jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
+ jprop = JavaClass("java.util.Properties", self._spark._sc._gateway._gateway_client)()
for k in properties:
jprop.setProperty(k, properties[k])
self._jwrite.mode(mode).jdbc(url, table, jprop)
@@ -865,7 +865,7 @@ def _test():
import os
import tempfile
from pyspark.context import SparkContext
- from pyspark.sql import Row, SQLContext, HiveContext
+ from pyspark.sql import SparkSession, Row, HiveContext
import pyspark.sql.readwriter
os.chdir(os.environ["SPARK_HOME"])
@@ -876,11 +876,13 @@ def _test():
globs['tempfile'] = tempfile
globs['os'] = os
globs['sc'] = sc
- globs['sqlContext'] = SQLContext(sc)
+ globs['spark'] = SparkSession.builder\
+ .enableHiveSupport()\
+ .getOrCreate()
globs['hiveContext'] = HiveContext._createForTesting(sc)
- globs['df'] = globs['sqlContext'].read.parquet('python/test_support/sql/parquet_partitioned')
+ globs['df'] = globs['spark'].read.parquet('python/test_support/sql/parquet_partitioned')
globs['sdf'] = \
- globs['sqlContext'].read.format('text').stream('python/test_support/sql/streaming')
+ globs['spark'].read.format('text').stream('python/test_support/sql/streaming')
(failure_count, test_count) = doctest.testmod(
pyspark.sql.readwriter, globs=globs,
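
The readwriter.py doctests above now resolve the name `spark` (a Hive-enabled SparkSession built in _test()) instead of `sqlContext`. Outside the doctest harness, an equivalent, illustrative setup looks roughly like the following; it assumes the test-support files shipped in the Spark source tree and SPARK_HOME as the working directory, and it omits enableHiveSupport() since none of these readers need Hive:

```python
from pyspark.sql import SparkSession

# Rough equivalent of the doctest globs: one session, reused for every reader example.
spark = SparkSession.builder \
    .appName("readwriter-doctest-sketch") \
    .getOrCreate()

df = spark.read.format('json').load('python/test_support/sql/people.json')
print(df.dtypes)   # [('age', 'bigint'), ('name', 'string')]

df2 = spark.read.parquet('python/test_support/sql/parquet_partitioned')
print(df2.dtypes)  # [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]

spark.stop()
```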
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index cd5c4a7b3e..0c73f58c3b 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -45,7 +45,7 @@ if sys.version_info[:2] <= (2, 6):
else:
import unittest
-from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
+from pyspark.sql import SparkSession, HiveContext, Column, Row
from pyspark.sql.types import *
from pyspark.sql.types import UserDefinedType, _infer_type
from pyspark.tests import ReusedPySparkTestCase
@@ -178,20 +178,6 @@ class DataTypeTests(unittest.TestCase):
self.assertEqual(dt.fromInternal(0), datetime.date(1970, 1, 1))
-class SQLContextTests(ReusedPySparkTestCase):
- def test_get_or_create(self):
- sqlCtx = SQLContext.getOrCreate(self.sc)
- self.assertTrue(SQLContext.getOrCreate(self.sc) is sqlCtx)
-
- def test_new_session(self):
- sqlCtx = SQLContext.getOrCreate(self.sc)
- sqlCtx.setConf("test_key", "a")
- sqlCtx2 = sqlCtx.newSession()
- sqlCtx2.setConf("test_key", "b")
- self.assertEqual(sqlCtx.getConf("test_key", ""), "a")
- self.assertEqual(sqlCtx2.getConf("test_key", ""), "b")
-
-
class SQLTests(ReusedPySparkTestCase):
@classmethod
@@ -199,15 +185,14 @@ class SQLTests(ReusedPySparkTestCase):
ReusedPySparkTestCase.setUpClass()
cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
os.unlink(cls.tempdir.name)
- cls.sparkSession = SparkSession(cls.sc)
- cls.sqlCtx = cls.sparkSession._wrapped
+ cls.spark = SparkSession(cls.sc)
cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
- rdd = cls.sc.parallelize(cls.testData, 2)
- cls.df = rdd.toDF()
+ cls.df = cls.spark.createDataFrame(cls.testData)
@classmethod
def tearDownClass(cls):
ReusedPySparkTestCase.tearDownClass()
+ cls.spark.stop()
shutil.rmtree(cls.tempdir.name, ignore_errors=True)
def test_row_should_be_read_only(self):
@@ -218,7 +203,7 @@ class SQLTests(ReusedPySparkTestCase):
row.a = 3
self.assertRaises(Exception, foo)
- row2 = self.sqlCtx.range(10).first()
+ row2 = self.spark.range(10).first()
self.assertEqual(0, row2.id)
def foo2():
@@ -226,14 +211,14 @@ class SQLTests(ReusedPySparkTestCase):
self.assertRaises(Exception, foo2)
def test_range(self):
- self.assertEqual(self.sqlCtx.range(1, 1).count(), 0)
- self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1)
- self.assertEqual(self.sqlCtx.range(0, 1 << 40, 1 << 39).count(), 2)
- self.assertEqual(self.sqlCtx.range(-2).count(), 0)
- self.assertEqual(self.sqlCtx.range(3).count(), 3)
+ self.assertEqual(self.spark.range(1, 1).count(), 0)
+ self.assertEqual(self.spark.range(1, 0, -1).count(), 1)
+ self.assertEqual(self.spark.range(0, 1 << 40, 1 << 39).count(), 2)
+ self.assertEqual(self.spark.range(-2).count(), 0)
+ self.assertEqual(self.spark.range(3).count(), 3)
def test_duplicated_column_names(self):
- df = self.sqlCtx.createDataFrame([(1, 2)], ["c", "c"])
+ df = self.spark.createDataFrame([(1, 2)], ["c", "c"])
row = df.select('*').first()
self.assertEqual(1, row[0])
self.assertEqual(2, row[1])
@@ -247,7 +232,7 @@ class SQLTests(ReusedPySparkTestCase):
from pyspark.sql.functions import explode
d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})]
rdd = self.sc.parallelize(d)
- data = self.sqlCtx.createDataFrame(rdd)
+ data = self.spark.createDataFrame(rdd)
result = data.select(explode(data.intlist).alias("a")).select("a").collect()
self.assertEqual(result[0][0], 1)
@@ -269,7 +254,7 @@ class SQLTests(ReusedPySparkTestCase):
def test_udf_with_callable(self):
d = [Row(number=i, squared=i**2) for i in range(10)]
rdd = self.sc.parallelize(d)
- data = self.sqlCtx.createDataFrame(rdd)
+ data = self.spark.createDataFrame(rdd)
class PlusFour:
def __call__(self, col):
@@ -284,7 +269,7 @@ class SQLTests(ReusedPySparkTestCase):
def test_udf_with_partial_function(self):
d = [Row(number=i, squared=i**2) for i in range(10)]
rdd = self.sc.parallelize(d)
- data = self.sqlCtx.createDataFrame(rdd)
+ data = self.spark.createDataFrame(rdd)
def some_func(col, param):
if col is not None:
@@ -296,56 +281,56 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)
def test_udf(self):
- self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
- [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
+ self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
+ [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
self.assertEqual(row[0], 5)
def test_udf2(self):
- self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType())
- self.sqlCtx.createDataFrame(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
- [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
+ self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType())
+ self.spark.createDataFrame(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
+ [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
self.assertEqual(4, res[0])
def test_chained_udf(self):
- self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType())
- [row] = self.sqlCtx.sql("SELECT double(1)").collect()
+ self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType())
+ [row] = self.spark.sql("SELECT double(1)").collect()
self.assertEqual(row[0], 2)
- [row] = self.sqlCtx.sql("SELECT double(double(1))").collect()
+ [row] = self.spark.sql("SELECT double(double(1))").collect()
self.assertEqual(row[0], 4)
- [row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect()
+ [row] = self.spark.sql("SELECT double(double(1) + 1)").collect()
self.assertEqual(row[0], 6)
def test_multiple_udfs(self):
- self.sqlCtx.registerFunction("double", lambda x: x * 2, IntegerType())
- [row] = self.sqlCtx.sql("SELECT double(1), double(2)").collect()
+ self.spark.catalog.registerFunction("double", lambda x: x * 2, IntegerType())
+ [row] = self.spark.sql("SELECT double(1), double(2)").collect()
self.assertEqual(tuple(row), (2, 4))
- [row] = self.sqlCtx.sql("SELECT double(double(1)), double(double(2) + 2)").collect()
+ [row] = self.spark.sql("SELECT double(double(1)), double(double(2) + 2)").collect()
self.assertEqual(tuple(row), (4, 12))
- self.sqlCtx.registerFunction("add", lambda x, y: x + y, IntegerType())
- [row] = self.sqlCtx.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect()
+ self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType())
+ [row] = self.spark.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect()
self.assertEqual(tuple(row), (6, 5))
def test_udf_with_array_type(self):
d = [Row(l=list(range(3)), d={"key": list(range(5))})]
rdd = self.sc.parallelize(d)
- self.sqlCtx.createDataFrame(rdd).registerTempTable("test")
- self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
- self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType())
- [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect()
+ self.spark.createDataFrame(rdd).registerTempTable("test")
+ self.spark.catalog.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
+ self.spark.catalog.registerFunction("maplen", lambda d: len(d), IntegerType())
+ [(l1, l2)] = self.spark.sql("select copylist(l), maplen(d) from test").collect()
self.assertEqual(list(range(3)), l1)
self.assertEqual(1, l2)
def test_broadcast_in_udf(self):
bar = {"a": "aa", "b": "bb", "c": "abc"}
foo = self.sc.broadcast(bar)
- self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
- [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect()
+ self.spark.catalog.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
+ [res] = self.spark.sql("SELECT MYUDF('c')").collect()
self.assertEqual("abc", res[0])
- [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
+ [res] = self.spark.sql("SELECT MYUDF('')").collect()
self.assertEqual("", res[0])
def test_udf_with_aggregate_function(self):
- df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+ df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
from pyspark.sql.functions import udf, col
from pyspark.sql.types import BooleanType
@@ -355,7 +340,7 @@ class SQLTests(ReusedPySparkTestCase):
def test_basic_functions(self):
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
- df = self.sqlCtx.read.json(rdd)
+ df = self.spark.read.json(rdd)
df.count()
df.collect()
df.schema
@@ -369,41 +354,41 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual(2, df.count())
df.registerTempTable("temp")
- df = self.sqlCtx.sql("select foo from temp")
+ df = self.spark.sql("select foo from temp")
df.count()
df.collect()
def test_apply_schema_to_row(self):
- df = self.sqlCtx.read.json(self.sc.parallelize(["""{"a":2}"""]))
- df2 = self.sqlCtx.createDataFrame(df.rdd.map(lambda x: x), df.schema)
+ df = self.spark.read.json(self.sc.parallelize(["""{"a":2}"""]))
+ df2 = self.spark.createDataFrame(df.rdd.map(lambda x: x), df.schema)
self.assertEqual(df.collect(), df2.collect())
rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
- df3 = self.sqlCtx.createDataFrame(rdd, df.schema)
+ df3 = self.spark.createDataFrame(rdd, df.schema)
self.assertEqual(10, df3.count())
def test_infer_schema_to_local(self):
input = [{"a": 1}, {"b": "coffee"}]
rdd = self.sc.parallelize(input)
- df = self.sqlCtx.createDataFrame(input)
- df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0)
+ df = self.spark.createDataFrame(input)
+ df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0)
self.assertEqual(df.schema, df2.schema)
rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
- df3 = self.sqlCtx.createDataFrame(rdd, df.schema)
+ df3 = self.spark.createDataFrame(rdd, df.schema)
self.assertEqual(10, df3.count())
def test_create_dataframe_schema_mismatch(self):
input = [Row(a=1)]
rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))
schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())])
- df = self.sqlCtx.createDataFrame(rdd, schema)
+ df = self.spark.createDataFrame(rdd, schema)
self.assertRaises(Exception, lambda: df.show())
def test_serialize_nested_array_and_map(self):
d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
rdd = self.sc.parallelize(d)
- df = self.sqlCtx.createDataFrame(rdd)
+ df = self.spark.createDataFrame(rdd)
row = df.head()
self.assertEqual(1, len(row.l))
self.assertEqual(1, row.l[0].a)
@@ -425,31 +410,31 @@ class SQLTests(ReusedPySparkTestCase):
d = [Row(l=[], d={}, s=None),
Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
rdd = self.sc.parallelize(d)
- df = self.sqlCtx.createDataFrame(rdd)
+ df = self.spark.createDataFrame(rdd)
self.assertEqual([], df.rdd.map(lambda r: r.l).first())
self.assertEqual([None, ""], df.rdd.map(lambda r: r.s).collect())
df.registerTempTable("test")
- result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'")
+ result = self.spark.sql("SELECT l[0].a from test where d['key'].d = '2'")
self.assertEqual(1, result.head()[0])
- df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0)
+ df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0)
self.assertEqual(df.schema, df2.schema)
self.assertEqual({}, df2.rdd.map(lambda r: r.d).first())
self.assertEqual([None, ""], df2.rdd.map(lambda r: r.s).collect())
df2.registerTempTable("test2")
- result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
+ result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
self.assertEqual(1, result.head()[0])
def test_infer_nested_schema(self):
NestedRow = Row("f1", "f2")
nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}),
NestedRow([2, 3], {"row2": 2.0})])
- df = self.sqlCtx.createDataFrame(nestedRdd1)
+ df = self.spark.createDataFrame(nestedRdd1)
self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0])
nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]),
NestedRow([[2, 3], [3, 4]], [2, 3])])
- df = self.sqlCtx.createDataFrame(nestedRdd2)
+ df = self.spark.createDataFrame(nestedRdd2)
self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0])
from collections import namedtuple
@@ -457,17 +442,17 @@ class SQLTests(ReusedPySparkTestCase):
rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"),
CustomRow(field1=2, field2="row2"),
CustomRow(field1=3, field2="row3")])
- df = self.sqlCtx.createDataFrame(rdd)
+ df = self.spark.createDataFrame(rdd)
self.assertEqual(Row(field1=1, field2=u'row1'), df.first())
def test_create_dataframe_from_objects(self):
data = [MyObject(1, "1"), MyObject(2, "2")]
- df = self.sqlCtx.createDataFrame(data)
+ df = self.spark.createDataFrame(data)
self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")])
self.assertEqual(df.first(), Row(key=1, value="1"))
def test_select_null_literal(self):
- df = self.sqlCtx.sql("select null as col")
+ df = self.spark.sql("select null as col")
self.assertEqual(Row(col=None), df.first())
def test_apply_schema(self):
@@ -488,7 +473,7 @@ class SQLTests(ReusedPySparkTestCase):
StructField("struct1", StructType([StructField("b", ShortType(), False)]), False),
StructField("list1", ArrayType(ByteType(), False), False),
StructField("null1", DoubleType(), True)])
- df = self.sqlCtx.createDataFrame(rdd, schema)
+ df = self.spark.createDataFrame(rdd, schema)
results = df.rdd.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1,
x.date1, x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1),
@@ -496,9 +481,9 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual(r, results.first())
df.registerTempTable("table2")
- r = self.sqlCtx.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
- "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " +
- "float1 + 1.5 as float1 FROM table2").first()
+ r = self.spark.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
+ "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " +
+ "float1 + 1.5 as float1 FROM table2").first()
self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r))
@@ -508,7 +493,7 @@ class SQLTests(ReusedPySparkTestCase):
abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]"
schema = _parse_schema_abstract(abstract)
typedSchema = _infer_schema_type(rdd.first(), schema)
- df = self.sqlCtx.createDataFrame(rdd, typedSchema)
+ df = self.spark.createDataFrame(rdd, typedSchema)
r = (127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), {"a": 1}, Row(b=2), [1, 2, 3])
self.assertEqual(r, tuple(df.first()))
@@ -524,7 +509,7 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual(1, row.asDict()['l'][0].a)
df = self.sc.parallelize([row]).toDF()
df.registerTempTable("test")
- row = self.sqlCtx.sql("select l, d from test").head()
+ row = self.spark.sql("select l, d from test").head()
self.assertEqual(1, row.asDict()["l"][0].a)
self.assertEqual(1.0, row.asDict()['d']['key'].c)
@@ -535,7 +520,7 @@ class SQLTests(ReusedPySparkTestCase):
def check_datatype(datatype):
pickled = pickle.loads(pickle.dumps(datatype))
assert datatype == pickled
- scala_datatype = self.sqlCtx._ssql_ctx.parseDataType(datatype.json())
+ scala_datatype = self.spark._wrapped._ssql_ctx.parseDataType(datatype.json())
python_datatype = _parse_datatype_json_string(scala_datatype.json())
assert datatype == python_datatype
@@ -560,21 +545,21 @@ class SQLTests(ReusedPySparkTestCase):
def test_infer_schema_with_udt(self):
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
- df = self.sqlCtx.createDataFrame([row])
+ df = self.spark.createDataFrame([row])
schema = df.schema
field = [f for f in schema.fields if f.name == "point"][0]
self.assertEqual(type(field.dataType), ExamplePointUDT)
df.registerTempTable("labeled_point")
- point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
+ point = self.spark.sql("SELECT point FROM labeled_point").head().point
self.assertEqual(point, ExamplePoint(1.0, 2.0))
row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
- df = self.sqlCtx.createDataFrame([row])
+ df = self.spark.createDataFrame([row])
schema = df.schema
field = [f for f in schema.fields if f.name == "point"][0]
self.assertEqual(type(field.dataType), PythonOnlyUDT)
df.registerTempTable("labeled_point")
- point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
+ point = self.spark.sql("SELECT point FROM labeled_point").head().point
self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
def test_apply_schema_with_udt(self):
@@ -582,21 +567,21 @@ class SQLTests(ReusedPySparkTestCase):
row = (1.0, ExamplePoint(1.0, 2.0))
schema = StructType([StructField("label", DoubleType(), False),
StructField("point", ExamplePointUDT(), False)])
- df = self.sqlCtx.createDataFrame([row], schema)
+ df = self.spark.createDataFrame([row], schema)
point = df.head().point
self.assertEqual(point, ExamplePoint(1.0, 2.0))
row = (1.0, PythonOnlyPoint(1.0, 2.0))
schema = StructType([StructField("label", DoubleType(), False),
StructField("point", PythonOnlyUDT(), False)])
- df = self.sqlCtx.createDataFrame([row], schema)
+ df = self.spark.createDataFrame([row], schema)
point = df.head().point
self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
def test_udf_with_udt(self):
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
- df = self.sqlCtx.createDataFrame([row])
+ df = self.spark.createDataFrame([row])
self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
udf = UserDefinedFunction(lambda p: p.y, DoubleType())
self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
@@ -604,7 +589,7 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
- df = self.sqlCtx.createDataFrame([row])
+ df = self.spark.createDataFrame([row])
self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
udf = UserDefinedFunction(lambda p: p.y, DoubleType())
self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
@@ -614,17 +599,17 @@ class SQLTests(ReusedPySparkTestCase):
def test_parquet_with_udt(self):
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
- df0 = self.sqlCtx.createDataFrame([row])
+ df0 = self.spark.createDataFrame([row])
output_dir = os.path.join(self.tempdir.name, "labeled_point")
df0.write.parquet(output_dir)
- df1 = self.sqlCtx.read.parquet(output_dir)
+ df1 = self.spark.read.parquet(output_dir)
point = df1.head().point
self.assertEqual(point, ExamplePoint(1.0, 2.0))
row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
- df0 = self.sqlCtx.createDataFrame([row])
+ df0 = self.spark.createDataFrame([row])
df0.write.parquet(output_dir, mode='overwrite')
- df1 = self.sqlCtx.read.parquet(output_dir)
+ df1 = self.spark.read.parquet(output_dir)
point = df1.head().point
self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
@@ -634,8 +619,8 @@ class SQLTests(ReusedPySparkTestCase):
row2 = (2.0, ExamplePoint(3.0, 4.0))
schema = StructType([StructField("label", DoubleType(), False),
StructField("point", ExamplePointUDT(), False)])
- df1 = self.sqlCtx.createDataFrame([row1], schema)
- df2 = self.sqlCtx.createDataFrame([row2], schema)
+ df1 = self.spark.createDataFrame([row1], schema)
+ df2 = self.spark.createDataFrame([row2], schema)
result = df1.union(df2).orderBy("label").collect()
self.assertEqual(
@@ -688,7 +673,7 @@ class SQLTests(ReusedPySparkTestCase):
def test_first_last_ignorenulls(self):
from pyspark.sql import functions
- df = self.sqlCtx.range(0, 100)
+ df = self.spark.range(0, 100)
df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id"))
df3 = df2.select(functions.first(df2.id, False).alias('a'),
functions.first(df2.id, True).alias('b'),
@@ -829,36 +814,36 @@ class SQLTests(ReusedPySparkTestCase):
schema = StructType([StructField("f1", StringType(), True, None),
StructField("f2", StringType(), True, {'a': None})])
rdd = self.sc.parallelize([["a", "b"], ["c", "d"]])
- self.sqlCtx.createDataFrame(rdd, schema)
+ self.spark.createDataFrame(rdd, schema)
def test_save_and_load(self):
df = self.df
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
df.write.json(tmpPath)
- actual = self.sqlCtx.read.json(tmpPath)
+ actual = self.spark.read.json(tmpPath)
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
schema = StructType([StructField("value", StringType(), True)])
- actual = self.sqlCtx.read.json(tmpPath, schema)
+ actual = self.spark.read.json(tmpPath, schema)
self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
df.write.json(tmpPath, "overwrite")
- actual = self.sqlCtx.read.json(tmpPath)
+ actual = self.spark.read.json(tmpPath)
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
df.write.save(format="json", mode="overwrite", path=tmpPath,
noUse="this options will not be used in save.")
- actual = self.sqlCtx.read.load(format="json", path=tmpPath,
- noUse="this options will not be used in load.")
+ actual = self.spark.read.load(format="json", path=tmpPath,
+ noUse="this options will not be used in load.")
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
- defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
+ defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default",
"org.apache.spark.sql.parquet")
- self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
- actual = self.sqlCtx.read.load(path=tmpPath)
+ self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+ actual = self.spark.read.load(path=tmpPath)
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
- self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+ self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
csvpath = os.path.join(tempfile.mkdtemp(), 'data')
df.write.option('quote', None).format('csv').save(csvpath)
@@ -870,36 +855,36 @@ class SQLTests(ReusedPySparkTestCase):
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
df.write.json(tmpPath)
- actual = self.sqlCtx.read.json(tmpPath)
+ actual = self.spark.read.json(tmpPath)
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
schema = StructType([StructField("value", StringType(), True)])
- actual = self.sqlCtx.read.json(tmpPath, schema)
+ actual = self.spark.read.json(tmpPath, schema)
self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
df.write.mode("overwrite").json(tmpPath)
- actual = self.sqlCtx.read.json(tmpPath)
+ actual = self.spark.read.json(tmpPath)
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
df.write.mode("overwrite").options(noUse="this options will not be used in save.")\
.option("noUse", "this option will not be used in save.")\
.format("json").save(path=tmpPath)
actual =\
- self.sqlCtx.read.format("json")\
- .load(path=tmpPath, noUse="this options will not be used in load.")
+ self.spark.read.format("json")\
+ .load(path=tmpPath, noUse="this options will not be used in load.")
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
- defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
+ defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default",
"org.apache.spark.sql.parquet")
- self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
- actual = self.sqlCtx.read.load(path=tmpPath)
+ self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+ actual = self.spark.read.load(path=tmpPath)
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
- self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+ self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
shutil.rmtree(tmpPath)
def test_stream_trigger_takes_keyword_args(self):
- df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
+ df = self.spark.read.format('text').stream('python/test_support/sql/streaming')
try:
df.write.trigger('5 seconds')
self.fail("Should have thrown an exception")
@@ -909,7 +894,7 @@ class SQLTests(ReusedPySparkTestCase):
def test_stream_read_options(self):
schema = StructType([StructField("data", StringType(), False)])
- df = self.sqlCtx.read.format('text').option('path', 'python/test_support/sql/streaming')\
+ df = self.spark.read.format('text').option('path', 'python/test_support/sql/streaming')\
.schema(schema).stream()
self.assertTrue(df.isStreaming)
self.assertEqual(df.schema.simpleString(), "struct<data:string>")
@@ -917,15 +902,15 @@ class SQLTests(ReusedPySparkTestCase):
def test_stream_read_options_overwrite(self):
bad_schema = StructType([StructField("test", IntegerType(), False)])
schema = StructType([StructField("data", StringType(), False)])
- df = self.sqlCtx.read.format('csv').option('path', 'python/test_support/sql/fake') \
+ df = self.spark.read.format('csv').option('path', 'python/test_support/sql/fake') \
.schema(bad_schema).stream(path='python/test_support/sql/streaming',
schema=schema, format='text')
self.assertTrue(df.isStreaming)
self.assertEqual(df.schema.simpleString(), "struct<data:string>")
def test_stream_save_options(self):
- df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
- for cq in self.sqlCtx.streams.active:
+ df = self.spark.read.format('text').stream('python/test_support/sql/streaming')
+ for cq in self.spark._wrapped.streams.active:
cq.stop()
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
@@ -948,8 +933,8 @@ class SQLTests(ReusedPySparkTestCase):
shutil.rmtree(tmpPath)
def test_stream_save_options_overwrite(self):
- df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
- for cq in self.sqlCtx.streams.active:
+ df = self.spark.read.format('text').stream('python/test_support/sql/streaming')
+ for cq in self.spark._wrapped.streams.active:
cq.stop()
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
@@ -977,8 +962,8 @@ class SQLTests(ReusedPySparkTestCase):
shutil.rmtree(tmpPath)
def test_stream_await_termination(self):
- df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
- for cq in self.sqlCtx.streams.active:
+ df = self.spark.read.format('text').stream('python/test_support/sql/streaming')
+ for cq in self.spark._wrapped.streams.active:
cq.stop()
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
@@ -1005,8 +990,8 @@ class SQLTests(ReusedPySparkTestCase):
shutil.rmtree(tmpPath)
def test_query_manager_await_termination(self):
- df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
- for cq in self.sqlCtx.streams.active:
+ df = self.spark.read.format('text').stream('python/test_support/sql/streaming')
+ for cq in self.spark._wrapped.streams.active:
cq.stop()
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
@@ -1018,13 +1003,13 @@ class SQLTests(ReusedPySparkTestCase):
try:
self.assertTrue(cq.isActive)
try:
- self.sqlCtx.streams.awaitAnyTermination("hello")
+ self.spark._wrapped.streams.awaitAnyTermination("hello")
self.fail("Expected a value exception")
except ValueError:
pass
now = time.time()
# test should take at least 2 seconds
- res = self.sqlCtx.streams.awaitAnyTermination(2.6)
+ res = self.spark._wrapped.streams.awaitAnyTermination(2.6)
duration = time.time() - now
self.assertTrue(duration >= 2)
self.assertFalse(res)
@@ -1035,7 +1020,7 @@ class SQLTests(ReusedPySparkTestCase):
def test_help_command(self):
# Regression test for SPARK-5464
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
- df = self.sqlCtx.read.json(rdd)
+ df = self.spark.read.json(rdd)
# render_doc() reproduces the help() exception without printing output
pydoc.render_doc(df)
pydoc.render_doc(df.foo)
@@ -1051,7 +1036,7 @@ class SQLTests(ReusedPySparkTestCase):
self.assertRaises(TypeError, lambda: df[{}])
def test_column_name_with_non_ascii(self):
- df = self.sqlCtx.createDataFrame([(1,)], ["数量"])
+ df = self.spark.createDataFrame([(1,)], ["数量"])
self.assertEqual(StructType([StructField("数量", LongType(), True)]), df.schema)
self.assertEqual("DataFrame[数量: bigint]", str(df))
self.assertEqual([("数量", 'bigint')], df.dtypes)
@@ -1084,7 +1069,7 @@ class SQLTests(ReusedPySparkTestCase):
# this saving as Parquet caused issues as well.
output_dir = os.path.join(self.tempdir.name, "infer_long_type")
df.write.parquet(output_dir)
- df1 = self.sqlCtx.read.parquet(output_dir)
+ df1 = self.spark.read.parquet(output_dir)
self.assertEqual('a', df1.first().f1)
self.assertEqual(100000000000000, df1.first().f2)
@@ -1100,7 +1085,7 @@ class SQLTests(ReusedPySparkTestCase):
time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000)
date = time.date()
row = Row(date=date, time=time)
- df = self.sqlCtx.createDataFrame([row])
+ df = self.spark.createDataFrame([row])
self.assertEqual(1, df.filter(df.date == date).count())
self.assertEqual(1, df.filter(df.time == time).count())
self.assertEqual(0, df.filter(df.date > date).count())
@@ -1110,7 +1095,7 @@ class SQLTests(ReusedPySparkTestCase):
dt1 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(0))
dt2 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(1))
row = Row(date=dt1)
- df = self.sqlCtx.createDataFrame([row])
+ df = self.spark.createDataFrame([row])
self.assertEqual(0, df.filter(df.date == dt2).count())
self.assertEqual(1, df.filter(df.date > dt2).count())
self.assertEqual(0, df.filter(df.date < dt2).count())
@@ -1125,7 +1110,7 @@ class SQLTests(ReusedPySparkTestCase):
utcnow = datetime.datetime.utcfromtimestamp(ts) # without microseconds
# add microseconds to utcnow (keeping year,month,day,hour,minute,second)
utcnow = datetime.datetime(*(utcnow.timetuple()[:6] + (now.microsecond, utc)))
- df = self.sqlCtx.createDataFrame([(day, now, utcnow)])
+ df = self.spark.createDataFrame([(day, now, utcnow)])
day1, now1, utcnow1 = df.first()
self.assertEqual(day1, day)
self.assertEqual(now, now1)
@@ -1134,13 +1119,13 @@ class SQLTests(ReusedPySparkTestCase):
def test_decimal(self):
from decimal import Decimal
schema = StructType([StructField("decimal", DecimalType(10, 5))])
- df = self.sqlCtx.createDataFrame([(Decimal("3.14159"),)], schema)
+ df = self.spark.createDataFrame([(Decimal("3.14159"),)], schema)
row = df.select(df.decimal + 1).first()
self.assertEqual(row[0], Decimal("4.14159"))
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
df.write.parquet(tmpPath)
- df2 = self.sqlCtx.read.parquet(tmpPath)
+ df2 = self.spark.read.parquet(tmpPath)
row = df2.first()
self.assertEqual(row[0], Decimal("3.14159"))
@@ -1151,52 +1136,52 @@ class SQLTests(ReusedPySparkTestCase):
StructField("height", DoubleType(), True)])
# shouldn't drop a non-null row
- self.assertEqual(self.sqlCtx.createDataFrame(
+ self.assertEqual(self.spark.createDataFrame(
[(u'Alice', 50, 80.1)], schema).dropna().count(),
1)
# dropping rows with a single null value
- self.assertEqual(self.sqlCtx.createDataFrame(
+ self.assertEqual(self.spark.createDataFrame(
[(u'Alice', None, 80.1)], schema).dropna().count(),
0)
- self.assertEqual(self.sqlCtx.createDataFrame(
+ self.assertEqual(self.spark.createDataFrame(
[(u'Alice', None, 80.1)], schema).dropna(how='any').count(),
0)
# if how = 'all', only drop rows if all values are null
- self.assertEqual(self.sqlCtx.createDataFrame(
+ self.assertEqual(self.spark.createDataFrame(
[(u'Alice', None, 80.1)], schema).dropna(how='all').count(),
1)
- self.assertEqual(self.sqlCtx.createDataFrame(
+ self.assertEqual(self.spark.createDataFrame(
[(None, None, None)], schema).dropna(how='all').count(),
0)
# how and subset
- self.assertEqual(self.sqlCtx.createDataFrame(
+ self.assertEqual(self.spark.createDataFrame(
[(u'Alice', 50, None)], schema).dropna(how='any', subset=['name', 'age']).count(),
1)
- self.assertEqual(self.sqlCtx.createDataFrame(
+ self.assertEqual(self.spark.createDataFrame(
[(u'Alice', None, None)], schema).dropna(how='any', subset=['name', 'age']).count(),
0)
# threshold
- self.assertEqual(self.sqlCtx.createDataFrame(
+ self.assertEqual(self.spark.createDataFrame(
[(u'Alice', None, 80.1)], schema).dropna(thresh=2).count(),
1)
- self.assertEqual(self.sqlCtx.createDataFrame(
+ self.assertEqual(self.spark.createDataFrame(
[(u'Alice', None, None)], schema).dropna(thresh=2).count(),
0)
# threshold and subset
- self.assertEqual(self.sqlCtx.createDataFrame(
+ self.assertEqual(self.spark.createDataFrame(
[(u'Alice', 50, None)], schema).dropna(thresh=2, subset=['name', 'age']).count(),
1)
- self.assertEqual(self.sqlCtx.createDataFrame(
+ self.assertEqual(self.spark.createDataFrame(
[(u'Alice', None, 180.9)], schema).dropna(thresh=2, subset=['name', 'age']).count(),
0)
# thresh should take precedence over how
- self.assertEqual(self.sqlCtx.createDataFrame(
+ self.assertEqual(self.spark.createDataFrame(
[(u'Alice', 50, None)], schema).dropna(
how='any', thresh=2, subset=['name', 'age']).count(),
1)
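The last case in this hunk pins down precedence: when both arguments are supplied, thresh wins over how. A standalone sketch of that behaviour, assuming spark is an active SparkSession:

    from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType

    schema = StructType([StructField("name", StringType(), True),
                         StructField("age", IntegerType(), True),
                         StructField("height", DoubleType(), True)])
    # (u'Alice', 50, None) has two non-null values out of three.
    df = spark.createDataFrame([(u'Alice', 50, None)], schema)
    assert df.dropna(how='any').count() == 0            # any null drops the row
    assert df.dropna(how='any', thresh=2).count() == 1  # thresh=2 overrides how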
@@ -1208,33 +1193,33 @@ class SQLTests(ReusedPySparkTestCase):
StructField("height", DoubleType(), True)])
# fillna shouldn't change non-null values
- row = self.sqlCtx.createDataFrame([(u'Alice', 10, 80.1)], schema).fillna(50).first()
+ row = self.spark.createDataFrame([(u'Alice', 10, 80.1)], schema).fillna(50).first()
self.assertEqual(row.age, 10)
# fillna with int
- row = self.sqlCtx.createDataFrame([(u'Alice', None, None)], schema).fillna(50).first()
+ row = self.spark.createDataFrame([(u'Alice', None, None)], schema).fillna(50).first()
self.assertEqual(row.age, 50)
self.assertEqual(row.height, 50.0)
# fillna with double
- row = self.sqlCtx.createDataFrame([(u'Alice', None, None)], schema).fillna(50.1).first()
+ row = self.spark.createDataFrame([(u'Alice', None, None)], schema).fillna(50.1).first()
self.assertEqual(row.age, 50)
self.assertEqual(row.height, 50.1)
# fillna with string
- row = self.sqlCtx.createDataFrame([(None, None, None)], schema).fillna("hello").first()
+ row = self.spark.createDataFrame([(None, None, None)], schema).fillna("hello").first()
self.assertEqual(row.name, u"hello")
self.assertEqual(row.age, None)
# fillna with subset specified for numeric cols
- row = self.sqlCtx.createDataFrame(
+ row = self.spark.createDataFrame(
[(None, None, None)], schema).fillna(50, subset=['name', 'age']).first()
self.assertEqual(row.name, None)
self.assertEqual(row.age, 50)
self.assertEqual(row.height, None)
# fillna with subset specified for numeric cols
- row = self.sqlCtx.createDataFrame(
+ row = self.spark.createDataFrame(
[(None, None, None)], schema).fillna("haha", subset=['name', 'age']).first()
self.assertEqual(row.name, "haha")
self.assertEqual(row.age, None)
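The subset cases pass because fillna is type-aware: a numeric fill value only touches numeric columns and a string fill value only string columns, even when a column of the other kind is named in the subset. A hedged sketch using the same three-column schema as the tests:

    from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType

    schema = StructType([StructField("name", StringType(), True),
                         StructField("age", IntegerType(), True),
                         StructField("height", DoubleType(), True)])
    row = spark.createDataFrame([(None, None, None)], schema) \
               .fillna(50, subset=["name", "age"]).first()
    # name is listed in the subset but stays None because 50 is numeric
    assert row.name is None and row.age == 50 and row.height is None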
@@ -1243,7 +1228,7 @@ class SQLTests(ReusedPySparkTestCase):
def test_bitwise_operations(self):
from pyspark.sql import functions
row = Row(a=170, b=75)
- df = self.sqlCtx.createDataFrame([row])
+ df = self.spark.createDataFrame([row])
result = df.select(df.a.bitwiseAND(df.b)).collect()[0].asDict()
self.assertEqual(170 & 75, result['(a & b)'])
result = df.select(df.a.bitwiseOR(df.b)).collect()[0].asDict()
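Column also exposes bitwiseXOR alongside the AND/OR helpers exercised here; a tiny sketch, again assuming an active SparkSession named spark:

    df = spark.createDataFrame([(170, 75)], ["a", "b"])
    r = df.select(df.a.bitwiseAND(df.b), df.a.bitwiseOR(df.b), df.a.bitwiseXOR(df.b)).first()
    assert tuple(r) == (170 & 75, 170 | 75, 170 ^ 75)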
@@ -1256,7 +1241,7 @@ class SQLTests(ReusedPySparkTestCase):
def test_expr(self):
from pyspark.sql import functions
row = Row(a="length string", b=75)
- df = self.sqlCtx.createDataFrame([row])
+ df = self.spark.createDataFrame([row])
result = df.select(functions.expr("length(a)")).collect()[0].asDict()
self.assertEqual(13, result["length(a)"])
@@ -1267,58 +1252,58 @@ class SQLTests(ReusedPySparkTestCase):
StructField("height", DoubleType(), True)])
# replace with int
- row = self.sqlCtx.createDataFrame([(u'Alice', 10, 10.0)], schema).replace(10, 20).first()
+ row = self.spark.createDataFrame([(u'Alice', 10, 10.0)], schema).replace(10, 20).first()
self.assertEqual(row.age, 20)
self.assertEqual(row.height, 20.0)
# replace with double
- row = self.sqlCtx.createDataFrame(
+ row = self.spark.createDataFrame(
[(u'Alice', 80, 80.0)], schema).replace(80.0, 82.1).first()
self.assertEqual(row.age, 82)
self.assertEqual(row.height, 82.1)
# replace with string
- row = self.sqlCtx.createDataFrame(
+ row = self.spark.createDataFrame(
[(u'Alice', 10, 80.1)], schema).replace(u'Alice', u'Ann').first()
self.assertEqual(row.name, u"Ann")
self.assertEqual(row.age, 10)
# replace with subset specified by a string of a column name w/ actual change
- row = self.sqlCtx.createDataFrame(
+ row = self.spark.createDataFrame(
[(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='age').first()
self.assertEqual(row.age, 20)
# replace with subset specified by a string of a column name w/o actual change
- row = self.sqlCtx.createDataFrame(
+ row = self.spark.createDataFrame(
[(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='height').first()
self.assertEqual(row.age, 10)
# replace with subset specified with one column replaced, another column not in subset
# stays unchanged.
- row = self.sqlCtx.createDataFrame(
+ row = self.spark.createDataFrame(
[(u'Alice', 10, 10.0)], schema).replace(10, 20, subset=['name', 'age']).first()
self.assertEqual(row.name, u'Alice')
self.assertEqual(row.age, 20)
self.assertEqual(row.height, 10.0)
# replace with subset specified but no column will be replaced
- row = self.sqlCtx.createDataFrame(
+ row = self.spark.createDataFrame(
[(u'Alice', 10, None)], schema).replace(10, 20, subset=['name', 'height']).first()
self.assertEqual(row.name, u'Alice')
self.assertEqual(row.age, 10)
self.assertEqual(row.height, None)
def test_capture_analysis_exception(self):
- self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc"))
+ self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc"))
self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b"))
def test_capture_parse_exception(self):
- self.assertRaises(ParseException, lambda: self.sqlCtx.sql("abc"))
+ self.assertRaises(ParseException, lambda: self.spark.sql("abc"))
def test_capture_illegalargument_exception(self):
self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks",
- lambda: self.sqlCtx.sql("SET mapred.reduce.tasks=-1"))
- df = self.sqlCtx.createDataFrame([(1, 2)], ["a", "b"])
+ lambda: self.spark.sql("SET mapred.reduce.tasks=-1"))
+ df = self.spark.createDataFrame([(1, 2)], ["a", "b"])
self.assertRaisesRegexp(IllegalArgumentException, "1024 is not in the permitted values",
lambda: df.select(sha2(df.a, 1024)).collect())
try:
@@ -1345,8 +1330,8 @@ class SQLTests(ReusedPySparkTestCase):
def test_functions_broadcast(self):
from pyspark.sql.functions import broadcast
- df1 = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
- df2 = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
+ df1 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
+ df2 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
# equijoin - should be converted into broadcast join
plan1 = df1.join(broadcast(df2), "key")._jdf.queryExecution().executedPlan()
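Outside the plan-level assertions, the broadcast hint itself is a one-liner; a minimal usage sketch with illustrative data, assuming spark is an active SparkSession:

    from pyspark.sql.functions import broadcast

    small = spark.createDataFrame([(1, "a"), (2, "b")], ("key", "value"))
    large = spark.createDataFrame([(i, str(i)) for i in range(1000)], ("key", "other"))
    # Hint that `small` is cheap enough to ship to every executor.
    joined = large.join(broadcast(small), "key")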
@@ -1396,9 +1381,9 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
def test_conf(self):
- spark = self.sparkSession
+ spark = self.spark
spark.conf.set("bogo", "sipeo")
- self.assertEqual(self.sparkSession.conf.get("bogo"), "sipeo")
+ self.assertEqual(spark.conf.get("bogo"), "sipeo")
spark.conf.set("bogo", "ta")
self.assertEqual(spark.conf.get("bogo"), "ta")
self.assertEqual(spark.conf.get("bogo", "not.read"), "ta")
@@ -1408,7 +1393,7 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia")
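spark.conf is the session's RuntimeConfig facade; a short sketch of the set/get/unset cycle (the key name is made up):

    spark.conf.set("example.key", "v1")
    assert spark.conf.get("example.key") == "v1"
    assert spark.conf.get("example.key", "fallback") == "v1"
    spark.conf.unset("example.key")
    assert spark.conf.get("example.key", "fallback") == "fallback"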
def test_current_database(self):
- spark = self.sparkSession
+ spark = self.spark
spark.catalog._reset()
self.assertEquals(spark.catalog.currentDatabase(), "default")
spark.sql("CREATE DATABASE some_db")
@@ -1420,7 +1405,7 @@ class SQLTests(ReusedPySparkTestCase):
lambda: spark.catalog.setCurrentDatabase("does_not_exist"))
def test_list_databases(self):
- spark = self.sparkSession
+ spark = self.spark
spark.catalog._reset()
databases = [db.name for db in spark.catalog.listDatabases()]
self.assertEquals(databases, ["default"])
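The catalog returns small record objects, so list comprehensions over .name are the idiomatic check; a hedged sketch with an illustrative database name:

    spark.sql("CREATE DATABASE IF NOT EXISTS demo_db")
    names = [db.name for db in spark.catalog.listDatabases()]
    assert "default" in names and "demo_db" in names
    spark.sql("DROP DATABASE demo_db")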
@@ -1430,7 +1415,7 @@ class SQLTests(ReusedPySparkTestCase):
def test_list_tables(self):
from pyspark.sql.catalog import Table
- spark = self.sparkSession
+ spark = self.spark
spark.catalog._reset()
spark.sql("CREATE DATABASE some_db")
self.assertEquals(spark.catalog.listTables(), [])
@@ -1475,7 +1460,7 @@ class SQLTests(ReusedPySparkTestCase):
def test_list_functions(self):
from pyspark.sql.catalog import Function
- spark = self.sparkSession
+ spark = self.spark
spark.catalog._reset()
spark.sql("CREATE DATABASE some_db")
functions = dict((f.name, f) for f in spark.catalog.listFunctions())
@@ -1512,7 +1497,7 @@ class SQLTests(ReusedPySparkTestCase):
def test_list_columns(self):
from pyspark.sql.catalog import Column
- spark = self.sparkSession
+ spark = self.spark
spark.catalog._reset()
spark.sql("CREATE DATABASE some_db")
spark.sql("CREATE TABLE tab1 (name STRING, age INT)")
@@ -1561,7 +1546,7 @@ class SQLTests(ReusedPySparkTestCase):
lambda: spark.catalog.listColumns("does_not_exist"))
def test_cache(self):
- spark = self.sparkSession
+ spark = self.spark
spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("tab1")
spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("tab2")
self.assertFalse(spark.catalog.isCached("tab1"))
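Cache state is managed through the catalog as well; a minimal sketch of the full cycle, assuming spark is an active SparkSession:

    spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("tmp_tab")
    spark.catalog.cacheTable("tmp_tab")
    assert spark.catalog.isCached("tmp_tab")
    spark.catalog.uncacheTable("tmp_tab")
    assert not spark.catalog.isCached("tmp_tab")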
@@ -1605,7 +1590,7 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
cls.tearDownClass()
raise unittest.SkipTest("Hive is not available")
os.unlink(cls.tempdir.name)
- cls.sqlCtx = HiveContext._createForTesting(cls.sc)
+ cls.spark = HiveContext._createForTesting(cls.sc)
cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
cls.df = cls.sc.parallelize(cls.testData).toDF()
@@ -1619,45 +1604,45 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
df.write.saveAsTable("savedJsonTable", "json", "append", path=tmpPath)
- actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath, "json")
+ actual = self.spark.createExternalTable("externalJsonTable", tmpPath, "json")
self.assertEqual(sorted(df.collect()),
- sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+ sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()))
self.assertEqual(sorted(df.collect()),
- sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+ sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()))
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
- self.sqlCtx.sql("DROP TABLE externalJsonTable")
+ self.spark.sql("DROP TABLE externalJsonTable")
df.write.saveAsTable("savedJsonTable", "json", "overwrite", path=tmpPath)
schema = StructType([StructField("value", StringType(), True)])
- actual = self.sqlCtx.createExternalTable("externalJsonTable", source="json",
- schema=schema, path=tmpPath,
- noUse="this options will not be used")
+ actual = self.spark.createExternalTable("externalJsonTable", source="json",
+ schema=schema, path=tmpPath,
+ noUse="this option will not be used")
self.assertEqual(sorted(df.collect()),
- sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+ sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()))
self.assertEqual(sorted(df.select("value").collect()),
- sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+ sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()))
self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
- self.sqlCtx.sql("DROP TABLE savedJsonTable")
- self.sqlCtx.sql("DROP TABLE externalJsonTable")
+ self.spark.sql("DROP TABLE savedJsonTable")
+ self.spark.sql("DROP TABLE externalJsonTable")
- defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
- "org.apache.spark.sql.parquet")
- self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+ defaultDataSourceName = self.spark.getConf("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
- actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath)
+ actual = self.spark.createExternalTable("externalJsonTable", path=tmpPath)
self.assertEqual(sorted(df.collect()),
- sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+ sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()))
self.assertEqual(sorted(df.collect()),
- sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+ sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()))
self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
- self.sqlCtx.sql("DROP TABLE savedJsonTable")
- self.sqlCtx.sql("DROP TABLE externalJsonTable")
- self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+ self.spark.sql("DROP TABLE savedJsonTable")
+ self.spark.sql("DROP TABLE externalJsonTable")
+ self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
shutil.rmtree(tmpPath)
def test_window_functions(self):
- df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+ df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
w = Window.partitionBy("value").orderBy("key")
from pyspark.sql import functions as F
sel = df.select(df.value, df.key,
@@ -1679,7 +1664,7 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
self.assertEqual(tuple(r), ex[:len(r)])
def test_window_functions_without_partitionBy(self):
- df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+ df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
w = Window.orderBy("key", df.value)
from pyspark.sql import functions as F
sel = df.select(df.value, df.key,
@@ -1701,7 +1686,7 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
self.assertEqual(tuple(r), ex[:len(r)])
def test_collect_functions(self):
- df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+ df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
from pyspark.sql import functions
self.assertEqual(