author     Sandeep Singh <sandeep@techaddict.me>  2016-05-11 11:24:16 -0700
committer  Davies Liu <davies.liu@gmail.com>      2016-05-11 11:24:16 -0700
commit     29314379729de4082bd2297c9e5289e3e4a0115e (patch)
tree       a5aede7207fde856910581f7f97f4b65b73a6e39 /python/pyspark/sql/tests.py
parent     d8935db5ecb7c959585411da9bf1e9a9c4d5cb37 (diff)
[SPARK-15037] [SQL] [MLLIB] Part2: Use SparkSession instead of SQLContext in Python TestSuites
## What changes were proposed in this pull request?
Use SparkSession instead of SQLContext in Python TestSuites
## How was this patch tested?
Existing tests
Author: Sandeep Singh <sandeep@techaddict.me>
Closes #13044 from techaddict/SPARK-15037-python.
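The change is mechanical but worth seeing in miniature. Below is a minimal sketch (not part of the patch) of the before/after pattern this commit applies throughout `python/pyspark/sql/tests.py`. The locally created `SparkContext` is a stand-in for the `cls.sc` that `ReusedPySparkTestCase` provides in the real tests:

```python
from pyspark import SparkContext
from pyspark.sql import SparkSession

sc = SparkContext("local[2]", "sparksession-sketch")  # stand-in for cls.sc

# Before: cls.sqlCtx = SQLContext(sc)
# After:  a SparkSession wrapping the same SparkContext, as setUpClass now does.
spark = SparkSession(sc)

# DataFrame creation and SQL now go through the session instead of the SQLContext.
df = spark.createDataFrame([(i, str(i)) for i in range(10)], ["key", "value"])
df.registerTempTable("t")  # API of this era, as used in these tests
rows = spark.sql("SELECT key FROM t WHERE key > 5").collect()
assert [r.key for r in rows] == [6, 7, 8, 9]

spark.stop()  # the patch also adds an explicit stop() in tearDownClass
sc.stop()
```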
Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r--  python/pyspark/sql/tests.py  379
1 file changed, 182 insertions(+), 197 deletions(-)
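One API shift in the diff below goes beyond a plain rename: UDF registration moves from `SQLContext.registerFunction` onto the session's catalog. A short sketch of that pattern, reusing the `spark` session from the sketch above:

```python
from pyspark.sql.types import IntegerType

# Before: self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
[row] = spark.sql("SELECT twoArgs('test', 1)").collect()
assert row[0] == 5  # len('test') + 1
```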
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index cd5c4a7b3e..0c73f58c3b 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -45,7 +45,7 @@ if sys.version_info[:2] <= (2, 6):
 else:
     import unittest
 
-from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
+from pyspark.sql import SparkSession, HiveContext, Column, Row
 from pyspark.sql.types import *
 from pyspark.sql.types import UserDefinedType, _infer_type
 from pyspark.tests import ReusedPySparkTestCase
@@ -178,20 +178,6 @@ class DataTypeTests(unittest.TestCase):
         self.assertEqual(dt.fromInternal(0), datetime.date(1970, 1, 1))
 
 
-class SQLContextTests(ReusedPySparkTestCase):
-    def test_get_or_create(self):
-        sqlCtx = SQLContext.getOrCreate(self.sc)
-        self.assertTrue(SQLContext.getOrCreate(self.sc) is sqlCtx)
-
-    def test_new_session(self):
-        sqlCtx = SQLContext.getOrCreate(self.sc)
-        sqlCtx.setConf("test_key", "a")
-        sqlCtx2 = sqlCtx.newSession()
-        sqlCtx2.setConf("test_key", "b")
-        self.assertEqual(sqlCtx.getConf("test_key", ""), "a")
-        self.assertEqual(sqlCtx2.getConf("test_key", ""), "b")
-
-
 class SQLTests(ReusedPySparkTestCase):
 
     @classmethod
@@ -199,15 +185,14 @@ class SQLTests(ReusedPySparkTestCase):
         ReusedPySparkTestCase.setUpClass()
         cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
         os.unlink(cls.tempdir.name)
-        cls.sparkSession = SparkSession(cls.sc)
-        cls.sqlCtx = cls.sparkSession._wrapped
+        cls.spark = SparkSession(cls.sc)
         cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
-        rdd = cls.sc.parallelize(cls.testData, 2)
-        cls.df = rdd.toDF()
+        cls.df = cls.spark.createDataFrame(cls.testData)
 
     @classmethod
     def tearDownClass(cls):
         ReusedPySparkTestCase.tearDownClass()
+        cls.spark.stop()
         shutil.rmtree(cls.tempdir.name, ignore_errors=True)
 
     def test_row_should_be_read_only(self):
@@ -218,7 +203,7 @@ class SQLTests(ReusedPySparkTestCase):
             row.a = 3
         self.assertRaises(Exception, foo)
 
-        row2 = self.sqlCtx.range(10).first()
+        row2 = self.spark.range(10).first()
         self.assertEqual(0, row2.id)
 
         def foo2():
@@ -226,14 +211,14 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertRaises(Exception, foo2)
 
     def test_range(self):
-        self.assertEqual(self.sqlCtx.range(1, 1).count(), 0)
-        self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1)
-        self.assertEqual(self.sqlCtx.range(0, 1 << 40, 1 << 39).count(), 2)
-        self.assertEqual(self.sqlCtx.range(-2).count(), 0)
-        self.assertEqual(self.sqlCtx.range(3).count(), 3)
+        self.assertEqual(self.spark.range(1, 1).count(), 0)
+        self.assertEqual(self.spark.range(1, 0, -1).count(), 1)
+        self.assertEqual(self.spark.range(0, 1 << 40, 1 << 39).count(), 2)
+        self.assertEqual(self.spark.range(-2).count(), 0)
+        self.assertEqual(self.spark.range(3).count(), 3)
 
     def test_duplicated_column_names(self):
-        df = self.sqlCtx.createDataFrame([(1, 2)], ["c", "c"])
+        df = self.spark.createDataFrame([(1, 2)], ["c", "c"])
         row = df.select('*').first()
         self.assertEqual(1, row[0])
         self.assertEqual(2, row[1])
@@ -247,7 +232,7 @@ class SQLTests(ReusedPySparkTestCase):
         from pyspark.sql.functions import explode
         d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})]
         rdd = self.sc.parallelize(d)
-        data = self.sqlCtx.createDataFrame(rdd)
+        data = self.spark.createDataFrame(rdd)
 
         result = data.select(explode(data.intlist).alias("a")).select("a").collect()
         self.assertEqual(result[0][0], 1)
@@ -269,7 +254,7 @@ class SQLTests(ReusedPySparkTestCase):
     def test_udf_with_callable(self):
         d = [Row(number=i, squared=i**2) for i in range(10)]
         rdd = self.sc.parallelize(d)
-        data = self.sqlCtx.createDataFrame(rdd)
+        data = self.spark.createDataFrame(rdd)
 
         class PlusFour:
             def __call__(self, col):
@@ -284,7 +269,7 @@ class SQLTests(ReusedPySparkTestCase):
     def test_udf_with_partial_function(self):
         d = [Row(number=i, squared=i**2) for i in range(10)]
         rdd = self.sc.parallelize(d)
-        data = self.sqlCtx.createDataFrame(rdd)
+        data = self.spark.createDataFrame(rdd)
 
         def some_func(col, param):
             if col is not None:
@@ -296,56 +281,56 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)
 
     def test_udf(self):
-        self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
-        [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
+        self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
+        [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
         self.assertEqual(row[0], 5)
 
     def test_udf2(self):
-        self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType())
-        self.sqlCtx.createDataFrame(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
-        [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
+        self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType())
+        self.spark.createDataFrame(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
+        [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
         self.assertEqual(4, res[0])
 
     def test_chained_udf(self):
-        self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType())
-        [row] = self.sqlCtx.sql("SELECT double(1)").collect()
+        self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType())
+        [row] = self.spark.sql("SELECT double(1)").collect()
         self.assertEqual(row[0], 2)
-        [row] = self.sqlCtx.sql("SELECT double(double(1))").collect()
+        [row] = self.spark.sql("SELECT double(double(1))").collect()
         self.assertEqual(row[0], 4)
-        [row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect()
+        [row] = self.spark.sql("SELECT double(double(1) + 1)").collect()
         self.assertEqual(row[0], 6)
 
     def test_multiple_udfs(self):
-        self.sqlCtx.registerFunction("double", lambda x: x * 2, IntegerType())
-        [row] = self.sqlCtx.sql("SELECT double(1), double(2)").collect()
+        self.spark.catalog.registerFunction("double", lambda x: x * 2, IntegerType())
+        [row] = self.spark.sql("SELECT double(1), double(2)").collect()
         self.assertEqual(tuple(row), (2, 4))
-        [row] = self.sqlCtx.sql("SELECT double(double(1)), double(double(2) + 2)").collect()
+        [row] = self.spark.sql("SELECT double(double(1)), double(double(2) + 2)").collect()
         self.assertEqual(tuple(row), (4, 12))
-        self.sqlCtx.registerFunction("add", lambda x, y: x + y, IntegerType())
-        [row] = self.sqlCtx.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect()
+        self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType())
+        [row] = self.spark.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect()
         self.assertEqual(tuple(row), (6, 5))
 
     def test_udf_with_array_type(self):
         d = [Row(l=list(range(3)), d={"key": list(range(5))})]
         rdd = self.sc.parallelize(d)
-        self.sqlCtx.createDataFrame(rdd).registerTempTable("test")
-        self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
-        self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType())
-        [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect()
+        self.spark.createDataFrame(rdd).registerTempTable("test")
+        self.spark.catalog.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
+        self.spark.catalog.registerFunction("maplen", lambda d: len(d), IntegerType())
+        [(l1, l2)] = self.spark.sql("select copylist(l), maplen(d) from test").collect()
         self.assertEqual(list(range(3)), l1)
         self.assertEqual(1, l2)
 
     def test_broadcast_in_udf(self):
         bar = {"a": "aa", "b": "bb", "c": "abc"}
         foo = self.sc.broadcast(bar)
-        self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
-        [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect()
+        self.spark.catalog.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
+        [res] = self.spark.sql("SELECT MYUDF('c')").collect()
         self.assertEqual("abc", res[0])
-        [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
+        [res] = self.spark.sql("SELECT MYUDF('')").collect()
         self.assertEqual("", res[0])
 
     def test_udf_with_aggregate_function(self):
-        df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
         from pyspark.sql.functions import udf, col
         from pyspark.sql.types import BooleanType
@@ -355,7 +340,7 @@ class SQLTests(ReusedPySparkTestCase):
 
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
-        df = self.sqlCtx.read.json(rdd)
+        df = self.spark.read.json(rdd)
         df.count()
         df.collect()
         df.schema
@@ -369,41 +354,41 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(2, df.count())
 
         df.registerTempTable("temp")
-        df = self.sqlCtx.sql("select foo from temp")
+        df = self.spark.sql("select foo from temp")
         df.count()
         df.collect()
 
     def test_apply_schema_to_row(self):
-        df = self.sqlCtx.read.json(self.sc.parallelize(["""{"a":2}"""]))
-        df2 = self.sqlCtx.createDataFrame(df.rdd.map(lambda x: x), df.schema)
+        df = self.spark.read.json(self.sc.parallelize(["""{"a":2}"""]))
+        df2 = self.spark.createDataFrame(df.rdd.map(lambda x: x), df.schema)
         self.assertEqual(df.collect(), df2.collect())
 
         rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
-        df3 = self.sqlCtx.createDataFrame(rdd, df.schema)
+        df3 = self.spark.createDataFrame(rdd, df.schema)
         self.assertEqual(10, df3.count())
 
     def test_infer_schema_to_local(self):
         input = [{"a": 1}, {"b": "coffee"}]
         rdd = self.sc.parallelize(input)
-        df = self.sqlCtx.createDataFrame(input)
-        df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0)
+        df = self.spark.createDataFrame(input)
+        df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0)
         self.assertEqual(df.schema, df2.schema)
 
         rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
-        df3 = self.sqlCtx.createDataFrame(rdd, df.schema)
+        df3 = self.spark.createDataFrame(rdd, df.schema)
         self.assertEqual(10, df3.count())
 
     def test_create_dataframe_schema_mismatch(self):
         input = [Row(a=1)]
         rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))
         schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())])
-        df = self.sqlCtx.createDataFrame(rdd, schema)
+        df = self.spark.createDataFrame(rdd, schema)
         self.assertRaises(Exception, lambda: df.show())
 
     def test_serialize_nested_array_and_map(self):
         d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
         rdd = self.sc.parallelize(d)
-        df = self.sqlCtx.createDataFrame(rdd)
+        df = self.spark.createDataFrame(rdd)
         row = df.head()
         self.assertEqual(1, len(row.l))
         self.assertEqual(1, row.l[0].a)
@@ -425,31 +410,31 @@ class SQLTests(ReusedPySparkTestCase):
         d = [Row(l=[], d={}, s=None),
              Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
         rdd = self.sc.parallelize(d)
-        df = self.sqlCtx.createDataFrame(rdd)
+        df = self.spark.createDataFrame(rdd)
         self.assertEqual([], df.rdd.map(lambda r: r.l).first())
         self.assertEqual([None, ""], df.rdd.map(lambda r: r.s).collect())
         df.registerTempTable("test")
-        result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'")
+        result = self.spark.sql("SELECT l[0].a from test where d['key'].d = '2'")
         self.assertEqual(1, result.head()[0])
 
-        df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0)
+        df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0)
         self.assertEqual(df.schema, df2.schema)
         self.assertEqual({}, df2.rdd.map(lambda r: r.d).first())
         self.assertEqual([None, ""], df2.rdd.map(lambda r: r.s).collect())
         df2.registerTempTable("test2")
-        result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
+        result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
         self.assertEqual(1, result.head()[0])
 
     def test_infer_nested_schema(self):
         NestedRow = Row("f1", "f2")
         nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}),
                                           NestedRow([2, 3], {"row2": 2.0})])
-        df = self.sqlCtx.createDataFrame(nestedRdd1)
+        df = self.spark.createDataFrame(nestedRdd1)
         self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0])
 
         nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]),
                                           NestedRow([[2, 3], [3, 4]], [2, 3])])
-        df = self.sqlCtx.createDataFrame(nestedRdd2)
+        df = self.spark.createDataFrame(nestedRdd2)
         self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0])
 
         from collections import namedtuple
@@ -457,17 +442,17 @@ class SQLTests(ReusedPySparkTestCase):
         rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"),
                                    CustomRow(field1=2, field2="row2"),
                                    CustomRow(field1=3, field2="row3")])
-        df = self.sqlCtx.createDataFrame(rdd)
+        df = self.spark.createDataFrame(rdd)
         self.assertEqual(Row(field1=1, field2=u'row1'), df.first())
 
     def test_create_dataframe_from_objects(self):
         data = [MyObject(1, "1"), MyObject(2, "2")]
-        df = self.sqlCtx.createDataFrame(data)
+        df = self.spark.createDataFrame(data)
         self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")])
         self.assertEqual(df.first(), Row(key=1, value="1"))
 
     def test_select_null_literal(self):
-        df = self.sqlCtx.sql("select null as col")
+        df = self.spark.sql("select null as col")
         self.assertEqual(Row(col=None), df.first())
 
     def test_apply_schema(self):
@@ -488,7 +473,7 @@ class SQLTests(ReusedPySparkTestCase):
             StructField("struct1", StructType([StructField("b", ShortType(), False)]), False),
             StructField("list1", ArrayType(ByteType(), False), False),
             StructField("null1", DoubleType(), True)])
-        df = self.sqlCtx.createDataFrame(rdd, schema)
+        df = self.spark.createDataFrame(rdd, schema)
         results = df.rdd.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1,
                              x.date1, x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
         r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1),
@@ -496,9 +481,9 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(r, results.first())
 
         df.registerTempTable("table2")
-        r = self.sqlCtx.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
-                            "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " +
-                            "float1 + 1.5 as float1 FROM table2").first()
+        r = self.spark.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
+                           "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " +
+                           "float1 + 1.5 as float1 FROM table2").first()
 
         self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r))
@@ -508,7 +493,7 @@ class SQLTests(ReusedPySparkTestCase):
         abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]"
         schema = _parse_schema_abstract(abstract)
         typedSchema = _infer_schema_type(rdd.first(), schema)
-        df = self.sqlCtx.createDataFrame(rdd, typedSchema)
+        df = self.spark.createDataFrame(rdd, typedSchema)
         r = (127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), {"a": 1}, Row(b=2), [1, 2, 3])
         self.assertEqual(r, tuple(df.first()))
@@ -524,7 +509,7 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(1, row.asDict()['l'][0].a)
         df = self.sc.parallelize([row]).toDF()
         df.registerTempTable("test")
-        row = self.sqlCtx.sql("select l, d from test").head()
+        row = self.spark.sql("select l, d from test").head()
         self.assertEqual(1, row.asDict()["l"][0].a)
         self.assertEqual(1.0, row.asDict()['d']['key'].c)
@@ -535,7 +520,7 @@ class SQLTests(ReusedPySparkTestCase):
         def check_datatype(datatype):
             pickled = pickle.loads(pickle.dumps(datatype))
             assert datatype == pickled
-            scala_datatype = self.sqlCtx._ssql_ctx.parseDataType(datatype.json())
+            scala_datatype = self.spark._wrapped._ssql_ctx.parseDataType(datatype.json())
             python_datatype = _parse_datatype_json_string(scala_datatype.json())
             assert datatype == python_datatype
@@ -560,21 +545,21 @@ class SQLTests(ReusedPySparkTestCase):
     def test_infer_schema_with_udt(self):
         from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
-        df = self.sqlCtx.createDataFrame([row])
+        df = self.spark.createDataFrame([row])
         schema = df.schema
         field = [f for f in schema.fields if f.name == "point"][0]
         self.assertEqual(type(field.dataType), ExamplePointUDT)
         df.registerTempTable("labeled_point")
-        point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
+        point = self.spark.sql("SELECT point FROM labeled_point").head().point
         self.assertEqual(point, ExamplePoint(1.0, 2.0))
 
         row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
-        df = self.sqlCtx.createDataFrame([row])
+        df = self.spark.createDataFrame([row])
         schema = df.schema
         field = [f for f in schema.fields if f.name == "point"][0]
         self.assertEqual(type(field.dataType), PythonOnlyUDT)
         df.registerTempTable("labeled_point")
-        point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
+        point = self.spark.sql("SELECT point FROM labeled_point").head().point
         self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
 
     def test_apply_schema_with_udt(self):
@@ -582,21 +567,21 @@ class SQLTests(ReusedPySparkTestCase):
         row = (1.0, ExamplePoint(1.0, 2.0))
         schema = StructType([StructField("label", DoubleType(), False),
                              StructField("point", ExamplePointUDT(), False)])
-        df = self.sqlCtx.createDataFrame([row], schema)
+        df = self.spark.createDataFrame([row], schema)
         point = df.head().point
         self.assertEqual(point, ExamplePoint(1.0, 2.0))
 
         row = (1.0, PythonOnlyPoint(1.0, 2.0))
         schema = StructType([StructField("label", DoubleType(), False),
                              StructField("point", PythonOnlyUDT(), False)])
-        df = self.sqlCtx.createDataFrame([row], schema)
+        df = self.spark.createDataFrame([row], schema)
         point = df.head().point
         self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
 
     def test_udf_with_udt(self):
         from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
-        df = self.sqlCtx.createDataFrame([row])
+        df = self.spark.createDataFrame([row])
         self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
         udf = UserDefinedFunction(lambda p: p.y, DoubleType())
         self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
@@ -604,7 +589,7 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
 
         row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
-        df = self.sqlCtx.createDataFrame([row])
+        df = self.spark.createDataFrame([row])
         self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
         udf = UserDefinedFunction(lambda p: p.y, DoubleType())
         self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
@@ -614,17 +599,17 @@ class SQLTests(ReusedPySparkTestCase):
     def test_parquet_with_udt(self):
         from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
-        df0 = self.sqlCtx.createDataFrame([row])
+        df0 = self.spark.createDataFrame([row])
         output_dir = os.path.join(self.tempdir.name, "labeled_point")
         df0.write.parquet(output_dir)
-        df1 = self.sqlCtx.read.parquet(output_dir)
+        df1 = self.spark.read.parquet(output_dir)
         point = df1.head().point
         self.assertEqual(point, ExamplePoint(1.0, 2.0))
 
         row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
-        df0 = self.sqlCtx.createDataFrame([row])
+        df0 = self.spark.createDataFrame([row])
         df0.write.parquet(output_dir, mode='overwrite')
-        df1 = self.sqlCtx.read.parquet(output_dir)
+        df1 = self.spark.read.parquet(output_dir)
         point = df1.head().point
         self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
@@ -634,8 +619,8 @@ class SQLTests(ReusedPySparkTestCase):
         row2 = (2.0, ExamplePoint(3.0, 4.0))
         schema = StructType([StructField("label", DoubleType(), False),
                              StructField("point", ExamplePointUDT(), False)])
-        df1 = self.sqlCtx.createDataFrame([row1], schema)
-        df2 = self.sqlCtx.createDataFrame([row2], schema)
+        df1 = self.spark.createDataFrame([row1], schema)
+        df2 = self.spark.createDataFrame([row2], schema)
 
         result = df1.union(df2).orderBy("label").collect()
         self.assertEqual(
@@ -688,7 +673,7 @@ class SQLTests(ReusedPySparkTestCase):
 
     def test_first_last_ignorenulls(self):
         from pyspark.sql import functions
-        df = self.sqlCtx.range(0, 100)
+        df = self.spark.range(0, 100)
         df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id"))
         df3 = df2.select(functions.first(df2.id, False).alias('a'),
                          functions.first(df2.id, True).alias('b'),
@@ -829,36 +814,36 @@ class SQLTests(ReusedPySparkTestCase):
         schema = StructType([StructField("f1", StringType(), True, None),
                              StructField("f2", StringType(), True, {'a': None})])
         rdd = self.sc.parallelize([["a", "b"], ["c", "d"]])
-        self.sqlCtx.createDataFrame(rdd, schema)
+        self.spark.createDataFrame(rdd, schema)
 
     def test_save_and_load(self):
         df = self.df
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
         df.write.json(tmpPath)
-        actual = self.sqlCtx.read.json(tmpPath)
+        actual = self.spark.read.json(tmpPath)
        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
         schema = StructType([StructField("value", StringType(), True)])
-        actual = self.sqlCtx.read.json(tmpPath, schema)
+        actual = self.spark.read.json(tmpPath, schema)
         self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
 
         df.write.json(tmpPath, "overwrite")
-        actual = self.sqlCtx.read.json(tmpPath)
+        actual = self.spark.read.json(tmpPath)
         self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
         df.write.save(format="json", mode="overwrite", path=tmpPath,
                       noUse="this options will not be used in save.")
-        actual = self.sqlCtx.read.load(format="json", path=tmpPath,
-                                       noUse="this options will not be used in load.")
+        actual = self.spark.read.load(format="json", path=tmpPath,
+                                      noUse="this options will not be used in load.")
         self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
-        defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
+        defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default",
                                                     "org.apache.spark.sql.parquet")
-        self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
-        actual = self.sqlCtx.read.load(path=tmpPath)
+        self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+        actual = self.spark.read.load(path=tmpPath)
         self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-        self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+        self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
 
         csvpath = os.path.join(tempfile.mkdtemp(), 'data')
         df.write.option('quote', None).format('csv').save(csvpath)
@@ -870,36 +855,36 @@ class SQLTests(ReusedPySparkTestCase):
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
         df.write.json(tmpPath)
-        actual = self.sqlCtx.read.json(tmpPath)
+        actual = self.spark.read.json(tmpPath)
         self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
         schema = StructType([StructField("value", StringType(), True)])
-        actual = self.sqlCtx.read.json(tmpPath, schema)
+        actual = self.spark.read.json(tmpPath, schema)
         self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
 
         df.write.mode("overwrite").json(tmpPath)
-        actual = self.sqlCtx.read.json(tmpPath)
+        actual = self.spark.read.json(tmpPath)
         self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
         df.write.mode("overwrite").options(noUse="this options will not be used in save.")\
                 .option("noUse", "this option will not be used in save.")\
                 .format("json").save(path=tmpPath)
         actual =\
-            self.sqlCtx.read.format("json")\
-                .load(path=tmpPath, noUse="this options will not be used in load.")
+            self.spark.read.format("json")\
+                .load(path=tmpPath, noUse="this options will not be used in load.")
         self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
-        defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
+        defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default",
                                                     "org.apache.spark.sql.parquet")
-        self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
-        actual = self.sqlCtx.read.load(path=tmpPath)
+        self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+        actual = self.spark.read.load(path=tmpPath)
         self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-        self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+        self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
 
         shutil.rmtree(tmpPath)
 
     def test_stream_trigger_takes_keyword_args(self):
-        df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
+        df = self.spark.read.format('text').stream('python/test_support/sql/streaming')
         try:
             df.write.trigger('5 seconds')
             self.fail("Should have thrown an exception")
@@ -909,7 +894,7 @@ class SQLTests(ReusedPySparkTestCase):
 
     def test_stream_read_options(self):
         schema = StructType([StructField("data", StringType(), False)])
-        df = self.sqlCtx.read.format('text').option('path', 'python/test_support/sql/streaming')\
+        df = self.spark.read.format('text').option('path', 'python/test_support/sql/streaming')\
             .schema(schema).stream()
         self.assertTrue(df.isStreaming)
         self.assertEqual(df.schema.simpleString(), "struct<data:string>")
@@ -917,15 +902,15 @@ class SQLTests(ReusedPySparkTestCase):
     def test_stream_read_options_overwrite(self):
         bad_schema = StructType([StructField("test", IntegerType(), False)])
         schema = StructType([StructField("data", StringType(), False)])
-        df = self.sqlCtx.read.format('csv').option('path', 'python/test_support/sql/fake') \
+        df = self.spark.read.format('csv').option('path', 'python/test_support/sql/fake') \
             .schema(bad_schema).stream(path='python/test_support/sql/streaming',
                                        schema=schema, format='text')
         self.assertTrue(df.isStreaming)
         self.assertEqual(df.schema.simpleString(), "struct<data:string>")
 
     def test_stream_save_options(self):
-        df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
-        for cq in self.sqlCtx.streams.active:
+        df = self.spark.read.format('text').stream('python/test_support/sql/streaming')
+        for cq in self.spark._wrapped.streams.active:
             cq.stop()
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
@@ -948,8 +933,8 @@ class SQLTests(ReusedPySparkTestCase):
         shutil.rmtree(tmpPath)
 
     def test_stream_save_options_overwrite(self):
-        df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
-        for cq in self.sqlCtx.streams.active:
+        df = self.spark.read.format('text').stream('python/test_support/sql/streaming')
+        for cq in self.spark._wrapped.streams.active:
             cq.stop()
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
@@ -977,8 +962,8 @@ class SQLTests(ReusedPySparkTestCase):
         shutil.rmtree(tmpPath)
 
     def test_stream_await_termination(self):
-        df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
-        for cq in self.sqlCtx.streams.active:
+        df = self.spark.read.format('text').stream('python/test_support/sql/streaming')
+        for cq in self.spark._wrapped.streams.active:
             cq.stop()
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
@@ -1005,8 +990,8 @@ class SQLTests(ReusedPySparkTestCase):
         shutil.rmtree(tmpPath)
 
     def test_query_manager_await_termination(self):
-        df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
-        for cq in self.sqlCtx.streams.active:
+        df = self.spark.read.format('text').stream('python/test_support/sql/streaming')
+        for cq in self.spark._wrapped.streams.active:
             cq.stop()
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
@@ -1018,13 +1003,13 @@ class SQLTests(ReusedPySparkTestCase):
         try:
             self.assertTrue(cq.isActive)
             try:
-                self.sqlCtx.streams.awaitAnyTermination("hello")
+                self.spark._wrapped.streams.awaitAnyTermination("hello")
                 self.fail("Expected a value exception")
             except ValueError:
                 pass
             now = time.time()
             # test should take at least 2 seconds
-            res = self.sqlCtx.streams.awaitAnyTermination(2.6)
+            res = self.spark._wrapped.streams.awaitAnyTermination(2.6)
             duration = time.time() - now
             self.assertTrue(duration >= 2)
             self.assertFalse(res)
@@ -1035,7 +1020,7 @@ class SQLTests(ReusedPySparkTestCase):
     def test_help_command(self):
         # Regression test for SPARK-5464
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
-        df = self.sqlCtx.read.json(rdd)
+        df = self.spark.read.json(rdd)
         # render_doc() reproduces the help() exception without printing output
         pydoc.render_doc(df)
         pydoc.render_doc(df.foo)
@@ -1051,7 +1036,7 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertRaises(TypeError, lambda: df[{}])
 
     def test_column_name_with_non_ascii(self):
-        df = self.sqlCtx.createDataFrame([(1,)], ["数量"])
+        df = self.spark.createDataFrame([(1,)], ["数量"])
         self.assertEqual(StructType([StructField("数量", LongType(), True)]), df.schema)
         self.assertEqual("DataFrame[数量: bigint]", str(df))
         self.assertEqual([("数量", 'bigint')], df.dtypes)
@@ -1084,7 +1069,7 @@ class SQLTests(ReusedPySparkTestCase):
         # this saving as Parquet caused issues as well.
         output_dir = os.path.join(self.tempdir.name, "infer_long_type")
         df.write.parquet(output_dir)
-        df1 = self.sqlCtx.read.parquet(output_dir)
+        df1 = self.spark.read.parquet(output_dir)
         self.assertEqual('a', df1.first().f1)
         self.assertEqual(100000000000000, df1.first().f2)
@@ -1100,7 +1085,7 @@ class SQLTests(ReusedPySparkTestCase):
         time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000)
         date = time.date()
         row = Row(date=date, time=time)
-        df = self.sqlCtx.createDataFrame([row])
+        df = self.spark.createDataFrame([row])
         self.assertEqual(1, df.filter(df.date == date).count())
         self.assertEqual(1, df.filter(df.time == time).count())
         self.assertEqual(0, df.filter(df.date > date).count())
@@ -1110,7 +1095,7 @@ class SQLTests(ReusedPySparkTestCase):
         dt1 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(0))
         dt2 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(1))
         row = Row(date=dt1)
-        df = self.sqlCtx.createDataFrame([row])
+        df = self.spark.createDataFrame([row])
         self.assertEqual(0, df.filter(df.date == dt2).count())
         self.assertEqual(1, df.filter(df.date > dt2).count())
         self.assertEqual(0, df.filter(df.date < dt2).count())
@@ -1125,7 +1110,7 @@ class SQLTests(ReusedPySparkTestCase):
         utcnow = datetime.datetime.utcfromtimestamp(ts)  # without microseconds
         # add microseconds to utcnow (keeping year,month,day,hour,minute,second)
         utcnow = datetime.datetime(*(utcnow.timetuple()[:6] + (now.microsecond, utc)))
-        df = self.sqlCtx.createDataFrame([(day, now, utcnow)])
+        df = self.spark.createDataFrame([(day, now, utcnow)])
         day1, now1, utcnow1 = df.first()
         self.assertEqual(day1, day)
         self.assertEqual(now, now1)
@@ -1134,13 +1119,13 @@ class SQLTests(ReusedPySparkTestCase):
     def test_decimal(self):
         from decimal import Decimal
         schema = StructType([StructField("decimal", DecimalType(10, 5))])
-        df = self.sqlCtx.createDataFrame([(Decimal("3.14159"),)], schema)
+        df = self.spark.createDataFrame([(Decimal("3.14159"),)], schema)
         row = df.select(df.decimal + 1).first()
         self.assertEqual(row[0], Decimal("4.14159"))
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
         df.write.parquet(tmpPath)
-        df2 = self.sqlCtx.read.parquet(tmpPath)
+        df2 = self.spark.read.parquet(tmpPath)
         row = df2.first()
         self.assertEqual(row[0], Decimal("3.14159"))
@@ -1151,52 +1136,52 @@ class SQLTests(ReusedPySparkTestCase):
                              StructField("height", DoubleType(), True)])
 
         # shouldn't drop a non-null row
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
             [(u'Alice', 50, 80.1)], schema).dropna().count(),
             1)
 
         # dropping rows with a single null value
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
             [(u'Alice', None, 80.1)], schema).dropna().count(),
             0)
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', None, 80.1)], schema).dropna(how='any').count(),
            0)
 
         # if how = 'all', only drop rows if all values are null
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', None, 80.1)], schema).dropna(how='all').count(),
            1)
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
            [(None, None, None)], schema).dropna(how='all').count(),
            0)
 
         # how and subset
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', 50, None)], schema).dropna(how='any', subset=['name', 'age']).count(),
            1)
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
           [(u'Alice', None, None)], schema).dropna(how='any', subset=['name', 'age']).count(),
           0)
 
         # threshold
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', None, 80.1)], schema).dropna(thresh=2).count(),
            1)
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', None, None)], schema).dropna(thresh=2).count(),
            0)
 
         # threshold and subset
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', 50, None)], schema).dropna(thresh=2, subset=['name', 'age']).count(),
            1)
-        self.assertEqual(self.sqlCtx.createDataFrame(
           [(u'Alice', None, 180.9)], schema).dropna(thresh=2, subset=['name', 'age']).count(),
           0)
+        self.assertEqual(self.spark.createDataFrame(
           [(u'Alice', None, 180.9)], schema).dropna(thresh=2, subset=['name', 'age']).count(),
           0)
 
         # thresh should take precedence over how
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
             [(u'Alice', 50, None)], schema).dropna(
                 how='any', thresh=2, subset=['name', 'age']).count(),
             1)
@@ -1208,33 +1193,33 @@ class SQLTests(ReusedPySparkTestCase):
                              StructField("height", DoubleType(), True)])
 
         # fillna shouldn't change non-null values
-        row = self.sqlCtx.createDataFrame([(u'Alice', 10, 80.1)], schema).fillna(50).first()
+        row = self.spark.createDataFrame([(u'Alice', 10, 80.1)], schema).fillna(50).first()
         self.assertEqual(row.age, 10)
 
         # fillna with int
-        row = self.sqlCtx.createDataFrame([(u'Alice', None, None)], schema).fillna(50).first()
+        row = self.spark.createDataFrame([(u'Alice', None, None)], schema).fillna(50).first()
         self.assertEqual(row.age, 50)
         self.assertEqual(row.height, 50.0)
 
         # fillna with double
-        row = self.sqlCtx.createDataFrame([(u'Alice', None, None)], schema).fillna(50.1).first()
+        row = self.spark.createDataFrame([(u'Alice', None, None)], schema).fillna(50.1).first()
         self.assertEqual(row.age, 50)
         self.assertEqual(row.height, 50.1)
 
         # fillna with string
-        row = self.sqlCtx.createDataFrame([(None, None, None)], schema).fillna("hello").first()
+        row = self.spark.createDataFrame([(None, None, None)], schema).fillna("hello").first()
         self.assertEqual(row.name, u"hello")
         self.assertEqual(row.age, None)
 
         # fillna with subset specified for numeric cols
-        row = self.sqlCtx.createDataFrame(
+        row = self.spark.createDataFrame(
             [(None, None, None)], schema).fillna(50, subset=['name', 'age']).first()
         self.assertEqual(row.name, None)
         self.assertEqual(row.age, 50)
         self.assertEqual(row.height, None)
 
         # fillna with subset specified for numeric cols
-        row = self.sqlCtx.createDataFrame(
+        row = self.spark.createDataFrame(
             [(None, None, None)], schema).fillna("haha", subset=['name', 'age']).first()
         self.assertEqual(row.name, "haha")
         self.assertEqual(row.age, None)
@@ -1243,7 +1228,7 @@ class SQLTests(ReusedPySparkTestCase):
     def test_bitwise_operations(self):
         from pyspark.sql import functions
         row = Row(a=170, b=75)
-        df = self.sqlCtx.createDataFrame([row])
+        df = self.spark.createDataFrame([row])
         result = df.select(df.a.bitwiseAND(df.b)).collect()[0].asDict()
         self.assertEqual(170 & 75, result['(a & b)'])
         result = df.select(df.a.bitwiseOR(df.b)).collect()[0].asDict()
@@ -1256,7 +1241,7 @@ class SQLTests(ReusedPySparkTestCase):
     def test_expr(self):
         from pyspark.sql import functions
         row = Row(a="length string", b=75)
-        df = self.sqlCtx.createDataFrame([row])
+        df = self.spark.createDataFrame([row])
         result = df.select(functions.expr("length(a)")).collect()[0].asDict()
         self.assertEqual(13, result["length(a)"])
@@ -1267,58 +1252,58 @@ class SQLTests(ReusedPySparkTestCase):
                              StructField("height", DoubleType(), True)])
 
         # replace with int
-        row = self.sqlCtx.createDataFrame([(u'Alice', 10, 10.0)], schema).replace(10, 20).first()
+        row = self.spark.createDataFrame([(u'Alice', 10, 10.0)], schema).replace(10, 20).first()
         self.assertEqual(row.age, 20)
         self.assertEqual(row.height, 20.0)
 
         # replace with double
-        row = self.sqlCtx.createDataFrame(
+        row = self.spark.createDataFrame(
             [(u'Alice', 80, 80.0)], schema).replace(80.0, 82.1).first()
         self.assertEqual(row.age, 82)
         self.assertEqual(row.height, 82.1)
 
         # replace with string
-        row = self.sqlCtx.createDataFrame(
+        row = self.spark.createDataFrame(
             [(u'Alice', 10, 80.1)], schema).replace(u'Alice', u'Ann').first()
         self.assertEqual(row.name, u"Ann")
         self.assertEqual(row.age, 10)
 
         # replace with subset specified by a string of a column name w/ actual change
-        row = self.sqlCtx.createDataFrame(
+        row = self.spark.createDataFrame(
             [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='age').first()
         self.assertEqual(row.age, 20)
 
         # replace with subset specified by a string of a column name w/o actual change
-        row = self.sqlCtx.createDataFrame(
+        row = self.spark.createDataFrame(
             [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='height').first()
         self.assertEqual(row.age, 10)
 
         # replace with subset specified with one column replaced, another column not in subset
         # stays unchanged.
-        row = self.sqlCtx.createDataFrame(
+        row = self.spark.createDataFrame(
            [(u'Alice', 10, 10.0)], schema).replace(10, 20, subset=['name', 'age']).first()
         self.assertEqual(row.name, u'Alice')
         self.assertEqual(row.age, 20)
         self.assertEqual(row.height, 10.0)
 
         # replace with subset specified but no column will be replaced
-        row = self.sqlCtx.createDataFrame(
+        row = self.spark.createDataFrame(
            [(u'Alice', 10, None)], schema).replace(10, 20, subset=['name', 'height']).first()
         self.assertEqual(row.name, u'Alice')
         self.assertEqual(row.age, 10)
         self.assertEqual(row.height, None)
 
     def test_capture_analysis_exception(self):
-        self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc"))
+        self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc"))
         self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b"))
 
     def test_capture_parse_exception(self):
-        self.assertRaises(ParseException, lambda: self.sqlCtx.sql("abc"))
+        self.assertRaises(ParseException, lambda: self.spark.sql("abc"))
 
     def test_capture_illegalargument_exception(self):
         self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks",
-                                lambda: self.sqlCtx.sql("SET mapred.reduce.tasks=-1"))
-        df = self.sqlCtx.createDataFrame([(1, 2)], ["a", "b"])
+                                lambda: self.spark.sql("SET mapred.reduce.tasks=-1"))
+        df = self.spark.createDataFrame([(1, 2)], ["a", "b"])
         self.assertRaisesRegexp(IllegalArgumentException, "1024 is not in the permitted values",
                                 lambda: df.select(sha2(df.a, 1024)).collect())
         try:
@@ -1345,8 +1330,8 @@ class SQLTests(ReusedPySparkTestCase):
 
     def test_functions_broadcast(self):
         from pyspark.sql.functions import broadcast
 
-        df1 = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
-        df2 = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
+        df1 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
+        df2 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
 
         # equijoin - should be converted into broadcast join
         plan1 = df1.join(broadcast(df2), "key")._jdf.queryExecution().executedPlan()
@@ -1396,9 +1381,9 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
 
     def test_conf(self):
-        spark = self.sparkSession
+        spark = self.spark
         spark.conf.set("bogo", "sipeo")
-        self.assertEqual(self.sparkSession.conf.get("bogo"), "sipeo")
+        self.assertEqual(spark.conf.get("bogo"), "sipeo")
         spark.conf.set("bogo", "ta")
         self.assertEqual(spark.conf.get("bogo"), "ta")
         self.assertEqual(spark.conf.get("bogo", "not.read"), "ta")
@@ -1408,7 +1393,7 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia")
 
     def test_current_database(self):
-        spark = self.sparkSession
+        spark = self.spark
         spark.catalog._reset()
         self.assertEquals(spark.catalog.currentDatabase(), "default")
         spark.sql("CREATE DATABASE some_db")
@@ -1420,7 +1405,7 @@ class SQLTests(ReusedPySparkTestCase):
             lambda: spark.catalog.setCurrentDatabase("does_not_exist"))
 
     def test_list_databases(self):
-        spark = self.sparkSession
+        spark = self.spark
         spark.catalog._reset()
         databases = [db.name for db in spark.catalog.listDatabases()]
         self.assertEquals(databases, ["default"])
@@ -1430,7 +1415,7 @@ class SQLTests(ReusedPySparkTestCase):
 
     def test_list_tables(self):
         from pyspark.sql.catalog import Table
-        spark = self.sparkSession
+        spark = self.spark
         spark.catalog._reset()
         spark.sql("CREATE DATABASE some_db")
         self.assertEquals(spark.catalog.listTables(), [])
@@ -1475,7 +1460,7 @@ class SQLTests(ReusedPySparkTestCase):
 
     def test_list_functions(self):
         from pyspark.sql.catalog import Function
-        spark = self.sparkSession
+        spark = self.spark
         spark.catalog._reset()
         spark.sql("CREATE DATABASE some_db")
         functions = dict((f.name, f) for f in spark.catalog.listFunctions())
@@ -1512,7 +1497,7 @@ class SQLTests(ReusedPySparkTestCase):
 
     def test_list_columns(self):
         from pyspark.sql.catalog import Column
-        spark = self.sparkSession
+        spark = self.spark
         spark.catalog._reset()
         spark.sql("CREATE DATABASE some_db")
         spark.sql("CREATE TABLE tab1 (name STRING, age INT)")
@@ -1561,7 +1546,7 @@ class SQLTests(ReusedPySparkTestCase):
             lambda: spark.catalog.listColumns("does_not_exist"))
 
     def test_cache(self):
-        spark = self.sparkSession
+        spark = self.spark
         spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("tab1")
         spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("tab2")
         self.assertFalse(spark.catalog.isCached("tab1"))
@@ -1605,7 +1590,7 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
             cls.tearDownClass()
             raise unittest.SkipTest("Hive is not available")
         os.unlink(cls.tempdir.name)
-        cls.sqlCtx = HiveContext._createForTesting(cls.sc)
+        cls.spark = HiveContext._createForTesting(cls.sc)
         cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
         cls.df = cls.sc.parallelize(cls.testData).toDF()
 
@@ -1619,45 +1604,45 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
         df.write.saveAsTable("savedJsonTable", "json", "append", path=tmpPath)
-        actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath, "json")
+        actual = self.spark.createExternalTable("externalJsonTable", tmpPath, "json")
         self.assertEqual(sorted(df.collect()),
-                         sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+                         sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()))
         self.assertEqual(sorted(df.collect()),
-                         sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+                         sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()))
         self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-        self.sqlCtx.sql("DROP TABLE externalJsonTable")
+        self.spark.sql("DROP TABLE externalJsonTable")
 
         df.write.saveAsTable("savedJsonTable", "json", "overwrite", path=tmpPath)
         schema = StructType([StructField("value", StringType(), True)])
-        actual = self.sqlCtx.createExternalTable("externalJsonTable", source="json",
-                                                 schema=schema, path=tmpPath,
-                                                 noUse="this options will not be used")
+        actual = self.spark.createExternalTable("externalJsonTable", source="json",
+                                                schema=schema, path=tmpPath,
+                                                noUse="this options will not be used")
         self.assertEqual(sorted(df.collect()),
-                         sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+                         sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()))
         self.assertEqual(sorted(df.select("value").collect()),
-                         sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+                         sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()))
         self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
-        self.sqlCtx.sql("DROP TABLE savedJsonTable")
-        self.sqlCtx.sql("DROP TABLE externalJsonTable")
+        self.spark.sql("DROP TABLE savedJsonTable")
+        self.spark.sql("DROP TABLE externalJsonTable")
 
-        defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
-                                                    "org.apache.spark.sql.parquet")
-        self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+        defaultDataSourceName = self.spark.getConf("spark.sql.sources.default",
+                                                   "org.apache.spark.sql.parquet")
+        self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
         df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
-        actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath)
+        actual = self.spark.createExternalTable("externalJsonTable", path=tmpPath)
         self.assertEqual(sorted(df.collect()),
-                         sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+                         sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()))
         self.assertEqual(sorted(df.collect()),
-                         sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+                         sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()))
         self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-        self.sqlCtx.sql("DROP TABLE savedJsonTable")
-        self.sqlCtx.sql("DROP TABLE externalJsonTable")
-        self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+        self.spark.sql("DROP TABLE savedJsonTable")
+        self.spark.sql("DROP TABLE externalJsonTable")
+        self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
 
         shutil.rmtree(tmpPath)
 
     def test_window_functions(self):
-        df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
         w = Window.partitionBy("value").orderBy("key")
         from pyspark.sql import functions as F
         sel = df.select(df.value, df.key,
@@ -1679,7 +1664,7 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
             self.assertEqual(tuple(r), ex[:len(r)])
 
     def test_window_functions_without_partitionBy(self):
-        df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
         w = Window.orderBy("key", df.value)
         from pyspark.sql import functions as F
         sel = df.select(df.value, df.key,
@@ -1701,7 +1686,7 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
             self.assertEqual(tuple(r), ex[:len(r)])
 
     def test_collect_functions(self):
-        df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
         from pyspark.sql import functions
 
         self.assertEqual(