author     Davies Liu <davies@databricks.com>       2015-02-09 20:49:22 -0800
committer  Reynold Xin <rxin@databricks.com>        2015-02-09 20:49:22 -0800
commit     08488c175f2e8532cb6aab84da2abd9ad57179cc (patch)
tree       05b7076b329fd8710ab650f91c38af3bf1a6a2d9 /python/pyspark/sql/tests.py
parent     d302c4800bf2f74eceb731169ddf1766136b7398 (diff)
[SPARK-5469] restructure pyspark.sql into multiple files
All the DataTypes moved into pyspark.sql.types. The changes can be tracked by `--find-copies-harder -M25`:

```
davies@localhost:~/work/spark/python$ git diff --find-copies-harder -M25 --numstat master..
2       5       python/docs/pyspark.ml.rst
0       3       python/docs/pyspark.mllib.rst
10      2       python/docs/pyspark.sql.rst
1       1       python/pyspark/mllib/linalg.py
21      14      python/pyspark/{mllib => sql}/__init__.py
14      2108    python/pyspark/{sql.py => sql/context.py}
10      1772    python/pyspark/{sql.py => sql/dataframe.py}
7       6       python/pyspark/{sql_tests.py => sql/tests.py}
8       1465    python/pyspark/{sql.py => sql/types.py}
4       2       python/run-tests
1       1       sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
```

Also run `git blame -C -C python/pyspark/sql/context.py` to track the history.

Author: Davies Liu <davies@databricks.com>

Closes #4479 from davies/sql and squashes the following commits:

1b5f0a5 [Davies Liu] Merge branch 'master' of github.com:apache/spark into sql
2b2b983 [Davies Liu] restructure pyspark.sql
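For readers following the restructure, here is a minimal sketch (not part of the commit) of how the split package is used afterwards. It is based only on the imports exercised by the test file below; the `SparkContext(appName=...)` setup and the sample data are illustrative assumptions.

```
# Post-restructure layout (pyspark 1.3-era API), as exercised by tests.py below:
#   pyspark.sql        -> SQLContext, Column, Dsl, ...
#   pyspark.sql.types  -> all DataTypes (Row, IntegerType, StructType, UserDefinedType, ...)
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.types import Row

sc = SparkContext(appName="sql-restructure-demo")   # illustrative app name
sqlCtx = SQLContext(sc)
rdd = sc.parallelize([Row(key=i, value=str(i)) for i in range(3)])
df = sqlCtx.inferSchema(rdd)   # DataFrame itself now lives in pyspark.sql.dataframe
print(df.collect())
```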
Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r--  python/pyspark/sql/tests.py  300
1 file changed, 300 insertions(+), 0 deletions(-)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
new file mode 100644
index 0000000000..d25c6365ed
--- /dev/null
+++ b/python/pyspark/sql/tests.py
@@ -0,0 +1,300 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Unit tests for pyspark.sql; additional tests are implemented as doctests in
+individual modules.
+"""
+import os
+import sys
+import pydoc
+import shutil
+import tempfile
+
+if sys.version_info[:2] <= (2, 6):
+    try:
+        import unittest2 as unittest
+    except ImportError:
+        sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
+        sys.exit(1)
+else:
+    import unittest
+
+
+from pyspark.sql import SQLContext, Column
+from pyspark.sql.types import IntegerType, Row, ArrayType, StructType, StructField, \
+    UserDefinedType, DoubleType, LongType
+from pyspark.tests import ReusedPySparkTestCase
+
+
+class ExamplePointUDT(UserDefinedType):
+    """
+    User-defined type (UDT) for ExamplePoint.
+    """
+
+    @classmethod
+    def sqlType(cls):
+        return ArrayType(DoubleType(), False)
+
+    @classmethod
+    def module(cls):
+        return 'pyspark.tests'
+
+    @classmethod
+    def scalaUDT(cls):
+        return 'org.apache.spark.sql.test.ExamplePointUDT'
+
+    def serialize(self, obj):
+        return [obj.x, obj.y]
+
+    def deserialize(self, datum):
+        return ExamplePoint(datum[0], datum[1])
+
+
+class ExamplePoint:
+    """
+    An example class to demonstrate UDT in Scala, Java, and Python.
+    """
+
+    __UDT__ = ExamplePointUDT()
+
+    def __init__(self, x, y):
+        self.x = x
+        self.y = y
+
+    def __repr__(self):
+        return "ExamplePoint(%s,%s)" % (self.x, self.y)
+
+    def __str__(self):
+        return "(%s,%s)" % (self.x, self.y)
+
+    def __eq__(self, other):
+        return isinstance(other, ExamplePoint) and \
+            other.x == self.x and other.y == self.y
+
+
+class SQLTests(ReusedPySparkTestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        ReusedPySparkTestCase.setUpClass()
+        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(cls.tempdir.name)
+        cls.sqlCtx = SQLContext(cls.sc)
+        cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
+        rdd = cls.sc.parallelize(cls.testData)
+        cls.df = cls.sqlCtx.inferSchema(rdd)
+
+    @classmethod
+    def tearDownClass(cls):
+        ReusedPySparkTestCase.tearDownClass()
+        shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+
+    def test_udf(self):
+        self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
+        [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
+        self.assertEqual(row[0], 5)
+
+    def test_udf2(self):
+        self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType())
+        self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
+        [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
+        self.assertEqual(4, res[0])
+
+    def test_udf_with_array_type(self):
+        d = [Row(l=range(3), d={"key": range(5)})]
+        rdd = self.sc.parallelize(d)
+        self.sqlCtx.inferSchema(rdd).registerTempTable("test")
+        self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
+        self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType())
+        [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect()
+        self.assertEqual(range(3), l1)
+        self.assertEqual(1, l2)
+
+    def test_broadcast_in_udf(self):
+        bar = {"a": "aa", "b": "bb", "c": "abc"}
+        foo = self.sc.broadcast(bar)
+        self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
+        [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect()
+        self.assertEqual("abc", res[0])
+        [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
+        self.assertEqual("", res[0])
+
+    def test_basic_functions(self):
+        rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
+        df = self.sqlCtx.jsonRDD(rdd)
+        df.count()
+        df.collect()
+        df.schema()
+
+        # cache and checkpoint
+        self.assertFalse(df.is_cached)
+        df.persist()
+        df.unpersist()
+        df.cache()
+        self.assertTrue(df.is_cached)
+        self.assertEqual(2, df.count())
+
+        df.registerTempTable("temp")
+        df = self.sqlCtx.sql("select foo from temp")
+        df.count()
+        df.collect()
+
+    def test_apply_schema_to_row(self):
+        df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""]))
+        df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema())
+        self.assertEqual(df.collect(), df2.collect())
+
+        rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
+        df3 = self.sqlCtx.applySchema(rdd, df.schema())
+        self.assertEqual(10, df3.count())
+
+    def test_serialize_nested_array_and_map(self):
+        d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
+        rdd = self.sc.parallelize(d)
+        df = self.sqlCtx.inferSchema(rdd)
+        row = df.head()
+        self.assertEqual(1, len(row.l))
+        self.assertEqual(1, row.l[0].a)
+        self.assertEqual("2", row.d["key"].d)
+
+        l = df.map(lambda x: x.l).first()
+        self.assertEqual(1, len(l))
+        self.assertEqual('s', l[0].b)
+
+        d = df.map(lambda x: x.d).first()
+        self.assertEqual(1, len(d))
+        self.assertEqual(1.0, d["key"].c)
+
+        row = df.map(lambda x: x.d["key"]).first()
+        self.assertEqual(1.0, row.c)
+        self.assertEqual("2", row.d)
+
+    def test_infer_schema(self):
+        d = [Row(l=[], d={}),
+             Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
+        rdd = self.sc.parallelize(d)
+        df = self.sqlCtx.inferSchema(rdd)
+        self.assertEqual([], df.map(lambda r: r.l).first())
+        self.assertEqual([None, ""], df.map(lambda r: r.s).collect())
+        df.registerTempTable("test")
+        result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'")
+        self.assertEqual(1, result.head()[0])
+
+        df2 = self.sqlCtx.inferSchema(rdd, 1.0)
+        self.assertEqual(df.schema(), df2.schema())
+        self.assertEqual({}, df2.map(lambda r: r.d).first())
+        self.assertEqual([None, ""], df2.map(lambda r: r.s).collect())
+        df2.registerTempTable("test2")
+        result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
+        self.assertEqual(1, result.head()[0])
+
+    def test_struct_in_map(self):
+        d = [Row(m={Row(i=1): Row(s="")})]
+        rdd = self.sc.parallelize(d)
+        df = self.sqlCtx.inferSchema(rdd)
+        k, v = df.head().m.items()[0]
+        self.assertEqual(1, k.i)
+        self.assertEqual("", v.s)
+
+    def test_convert_row_to_dict(self):
+        row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
+        self.assertEqual(1, row.asDict()['l'][0].a)
+        rdd = self.sc.parallelize([row])
+        df = self.sqlCtx.inferSchema(rdd)
+        df.registerTempTable("test")
+        row = self.sqlCtx.sql("select l, d from test").head()
+        self.assertEqual(1, row.asDict()["l"][0].a)
+        self.assertEqual(1.0, row.asDict()['d']['key'].c)
+
+    def test_infer_schema_with_udt(self):
+        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
+        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+        rdd = self.sc.parallelize([row])
+        df = self.sqlCtx.inferSchema(rdd)
+        schema = df.schema()
+        field = [f for f in schema.fields if f.name == "point"][0]
+        self.assertEqual(type(field.dataType), ExamplePointUDT)
+        df.registerTempTable("labeled_point")
+        point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
+        self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+    def test_apply_schema_with_udt(self):
+        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
+        row = (1.0, ExamplePoint(1.0, 2.0))
+        rdd = self.sc.parallelize([row])
+        schema = StructType([StructField("label", DoubleType(), False),
+                             StructField("point", ExamplePointUDT(), False)])
+        df = self.sqlCtx.applySchema(rdd, schema)
+        point = df.head().point
+        self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+    def test_parquet_with_udt(self):
+        from pyspark.sql.tests import ExamplePoint
+        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+        rdd = self.sc.parallelize([row])
+        df0 = self.sqlCtx.inferSchema(rdd)
+        output_dir = os.path.join(self.tempdir.name, "labeled_point")
+        df0.saveAsParquetFile(output_dir)
+        df1 = self.sqlCtx.parquetFile(output_dir)
+        point = df1.head().point
+        self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+    def test_column_operators(self):
+        ci = self.df.key
+        cs = self.df.value
+        c = ci == cs
+        self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column))
+        rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci)
+        self.assertTrue(all(isinstance(c, Column) for c in rcc))
+        cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs]
+        self.assertTrue(all(isinstance(c, Column) for c in cb))
+        cbool = (ci & ci), (ci | ci), (~ci)
+        self.assertTrue(all(isinstance(c, Column) for c in cbool))
+        css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a')
+        self.assertTrue(all(isinstance(c, Column) for c in css))
+        self.assertTrue(isinstance(ci.cast(LongType()), Column))
+
+    def test_column_select(self):
+        df = self.df
+        self.assertEqual(self.testData, df.select("*").collect())
+        self.assertEqual(self.testData, df.select(df.key, df.value).collect())
+        self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())
+
+    def test_aggregator(self):
+        df = self.df
+        g = df.groupBy()
+        self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
+        self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
+
+        from pyspark.sql import Dsl
+        self.assertEqual((0, u'99'), tuple(g.agg(Dsl.first(df.key), Dsl.last(df.value)).first()))
+        self.assertTrue(95 < g.agg(Dsl.approxCountDistinct(df.key)).first()[0])
+        self.assertEqual(100, g.agg(Dsl.countDistinct(df.value)).first()[0])
+
+    def test_help_command(self):
+        # Regression test for SPARK-5464
+        rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
+        df = self.sqlCtx.jsonRDD(rdd)
+        # render_doc() reproduces the help() exception without printing output
+        pydoc.render_doc(df)
+        pydoc.render_doc(df.foo)
+        pydoc.render_doc(df.take(1))
+
+
+if __name__ == "__main__":
+    unittest.main()