author     Reynold Xin <rxin@databricks.com>  2015-01-27 16:08:24 -0800
committer  Reynold Xin <rxin@databricks.com>  2015-01-27 16:08:24 -0800
commit     119f45d61d7b48d376cca05e1b4f0c7fcf65bfa8 (patch)
tree       714df6362313e93bee0e9dba2f84b3ba1697e555 /python/pyspark/tests.py
parent     b1b35ca2e440df40b253bf967bb93705d355c1c0 (diff)
[SPARK-5097][SQL] DataFrame
This pull request redesigns the existing Spark SQL DSL, which already provides data-frame-like functionality.

TODOs: With the exception of Python support, other tasks can be done in separate, follow-up PRs.
- [ ] Audit of the API
- [ ] Documentation
- [ ] More test cases to cover the new API
- [x] Python support
- [ ] Type alias SchemaRDD

Author: Reynold Xin <rxin@databricks.com>
Author: Davies Liu <davies@databricks.com>

Closes #4173 from rxin/df1 and squashes the following commits:

0a1a73b [Reynold Xin] Merge branch 'df1' of github.com:rxin/spark into df1
23b4427 [Reynold Xin] Mima.
828f70d [Reynold Xin] Merge pull request #7 from davies/df
257b9e6 [Davies Liu] add repartition
6bf2b73 [Davies Liu] fix collect with UDT and tests
e971078 [Reynold Xin] Missing quotes.
b9306b4 [Reynold Xin] Remove removeColumn/updateColumn for now.
a728bf2 [Reynold Xin] Example rename.
e8aa3d3 [Reynold Xin] groupby -> groupBy.
9662c9e [Davies Liu] improve DataFrame Python API
4ae51ea [Davies Liu] python API for dataframe
1e5e454 [Reynold Xin] Fixed a bug with symbol conversion.
2ca74db [Reynold Xin] Couple minor fixes.
ea98ea1 [Reynold Xin] Documentation & literal expressions.
2b22684 [Reynold Xin] Got rid of IntelliJ problems.
02bbfbc [Reynold Xin] Tightening imports.
ffbce66 [Reynold Xin] Fixed compilation error.
59b6d8b [Reynold Xin] Style violation.
b85edfb [Reynold Xin] ALS.
8c37f0a [Reynold Xin] Made MLlib and examples compile
6d53134 [Reynold Xin] Hive module.
d35efd5 [Reynold Xin] Fixed compilation error.
ce4a5d2 [Reynold Xin] Fixed test cases in SQL except ParquetIOSuite.
66d5ef1 [Reynold Xin] SQLContext minor patch.
c9bcdc0 [Reynold Xin] Checkpoint: SQL module compiles!
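For orientation, a minimal sketch of the new Python DataFrame API this commit exercises; the calls below (inferSchema, where/select with Column expressions, groupBy().agg, registerTempTable) are taken from the updated tests in the diff, while the SparkContext setup is illustrative only and not part of the commit:

from pyspark import SparkContext
from pyspark.sql import SQLContext, Row

sc = SparkContext("local", "dataframe-sketch")
sqlCtx = SQLContext(sc)

# Build a DataFrame from an RDD of Rows, as setUp() does in the tests below.
rdd = sc.parallelize([Row(key=i, value=str(i)) for i in range(100)])
df = sqlCtx.inferSchema(rdd)

# Column expressions instead of SQL strings (see test_column_select).
df.where(df.key == 1).select(df.value).collect()   # [Row(value='1')]

# Dictionary-style aggregation over the whole frame (see test_aggregator).
df.groupBy().agg({'key': 'max', 'value': 'count'}).collect()

# The SQL path still works through temporary tables (see test_basic_functions).
df.registerTempTable("data")
sqlCtx.sql("select value from data where key = 1").collect()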
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r--  python/pyspark/tests.py  155
1 file changed, 86 insertions(+), 69 deletions(-)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index b474fcf5bf..e8e207af46 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -806,6 +806,9 @@ class SQLTests(ReusedPySparkTestCase):

    def setUp(self):
        self.sqlCtx = SQLContext(self.sc)
+        self.testData = [Row(key=i, value=str(i)) for i in range(100)]
+        rdd = self.sc.parallelize(self.testData)
+        self.df = self.sqlCtx.inferSchema(rdd)

    def test_udf(self):
        self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
@@ -821,7 +824,7 @@ class SQLTests(ReusedPySparkTestCase):
    def test_udf_with_array_type(self):
        d = [Row(l=range(3), d={"key": range(5)})]
        rdd = self.sc.parallelize(d)
-        srdd = self.sqlCtx.inferSchema(rdd).registerTempTable("test")
+        self.sqlCtx.inferSchema(rdd).registerTempTable("test")
        self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
        self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType())
        [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect()
@@ -839,68 +842,51 @@ class SQLTests(ReusedPySparkTestCase):
    def test_basic_functions(self):
        rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
-        srdd = self.sqlCtx.jsonRDD(rdd)
-        srdd.count()
-        srdd.collect()
-        srdd.schemaString()
-        srdd.schema()
+        df = self.sqlCtx.jsonRDD(rdd)
+        df.count()
+        df.collect()
+        df.schema()

        # cache and checkpoint
-        self.assertFalse(srdd.is_cached)
-        srdd.persist()
-        srdd.unpersist()
-        srdd.cache()
-        self.assertTrue(srdd.is_cached)
-        self.assertFalse(srdd.isCheckpointed())
-        self.assertEqual(None, srdd.getCheckpointFile())
-
-        srdd = srdd.coalesce(2, True)
-        srdd = srdd.repartition(3)
-        srdd = srdd.distinct()
-        srdd.intersection(srdd)
-        self.assertEqual(2, srdd.count())
-
-        srdd.registerTempTable("temp")
-        srdd = self.sqlCtx.sql("select foo from temp")
-        srdd.count()
-        srdd.collect()
-
-    def test_distinct(self):
-        rdd = self.sc.parallelize(['{"a": 1}', '{"b": 2}', '{"c": 3}']*10, 10)
-        srdd = self.sqlCtx.jsonRDD(rdd)
-        self.assertEquals(srdd.getNumPartitions(), 10)
-        self.assertEquals(srdd.distinct().count(), 3)
-        result = srdd.distinct(5)
-        self.assertEquals(result.getNumPartitions(), 5)
-        self.assertEquals(result.count(), 3)
+        self.assertFalse(df.is_cached)
+        df.persist()
+        df.unpersist()
+        df.cache()
+        self.assertTrue(df.is_cached)
+        self.assertEqual(2, df.count())
+
+        df.registerTempTable("temp")
+        df = self.sqlCtx.sql("select foo from temp")
+        df.count()
+        df.collect()

    def test_apply_schema_to_row(self):
-        srdd = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""]))
-        srdd2 = self.sqlCtx.applySchema(srdd.map(lambda x: x), srdd.schema())
-        self.assertEqual(srdd.collect(), srdd2.collect())
+        df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""]))
+        df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema())
+        self.assertEqual(df.collect(), df2.collect())

        rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
-        srdd3 = self.sqlCtx.applySchema(rdd, srdd.schema())
-        self.assertEqual(10, srdd3.count())
+        df3 = self.sqlCtx.applySchema(rdd, df.schema())
+        self.assertEqual(10, df3.count())

    def test_serialize_nested_array_and_map(self):
        d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
        rdd = self.sc.parallelize(d)
-        srdd = self.sqlCtx.inferSchema(rdd)
-        row = srdd.first()
+        df = self.sqlCtx.inferSchema(rdd)
+        row = df.head()
        self.assertEqual(1, len(row.l))
        self.assertEqual(1, row.l[0].a)
        self.assertEqual("2", row.d["key"].d)
-        l = srdd.map(lambda x: x.l).first()
+        l = df.map(lambda x: x.l).first()
        self.assertEqual(1, len(l))
        self.assertEqual('s', l[0].b)
-        d = srdd.map(lambda x: x.d).first()
+        d = df.map(lambda x: x.d).first()
        self.assertEqual(1, len(d))
        self.assertEqual(1.0, d["key"].c)
-        row = srdd.map(lambda x: x.d["key"]).first()
+        row = df.map(lambda x: x.d["key"]).first()
        self.assertEqual(1.0, row.c)
        self.assertEqual("2", row.d)
@@ -908,26 +894,26 @@ class SQLTests(ReusedPySparkTestCase):
        d = [Row(l=[], d={}),
             Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
        rdd = self.sc.parallelize(d)
-        srdd = self.sqlCtx.inferSchema(rdd)
-        self.assertEqual([], srdd.map(lambda r: r.l).first())
-        self.assertEqual([None, ""], srdd.map(lambda r: r.s).collect())
-        srdd.registerTempTable("test")
+        df = self.sqlCtx.inferSchema(rdd)
+        self.assertEqual([], df.map(lambda r: r.l).first())
+        self.assertEqual([None, ""], df.map(lambda r: r.s).collect())
+        df.registerTempTable("test")
        result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'")
-        self.assertEqual(1, result.first()[0])
+        self.assertEqual(1, result.head()[0])

-        srdd2 = self.sqlCtx.inferSchema(rdd, 1.0)
-        self.assertEqual(srdd.schema(), srdd2.schema())
-        self.assertEqual({}, srdd2.map(lambda r: r.d).first())
-        self.assertEqual([None, ""], srdd2.map(lambda r: r.s).collect())
-        srdd2.registerTempTable("test2")
+        df2 = self.sqlCtx.inferSchema(rdd, 1.0)
+        self.assertEqual(df.schema(), df2.schema())
+        self.assertEqual({}, df2.map(lambda r: r.d).first())
+        self.assertEqual([None, ""], df2.map(lambda r: r.s).collect())
+        df2.registerTempTable("test2")
        result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
-        self.assertEqual(1, result.first()[0])
+        self.assertEqual(1, result.head()[0])

    def test_struct_in_map(self):
        d = [Row(m={Row(i=1): Row(s="")})]
        rdd = self.sc.parallelize(d)
-        srdd = self.sqlCtx.inferSchema(rdd)
-        k, v = srdd.first().m.items()[0]
+        df = self.sqlCtx.inferSchema(rdd)
+        k, v = df.head().m.items()[0]
        self.assertEqual(1, k.i)
        self.assertEqual("", v.s)
@@ -935,9 +921,9 @@ class SQLTests(ReusedPySparkTestCase):
        row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
        self.assertEqual(1, row.asDict()['l'][0].a)
        rdd = self.sc.parallelize([row])
-        srdd = self.sqlCtx.inferSchema(rdd)
-        srdd.registerTempTable("test")
-        row = self.sqlCtx.sql("select l, d from test").first()
+        df = self.sqlCtx.inferSchema(rdd)
+        df.registerTempTable("test")
+        row = self.sqlCtx.sql("select l, d from test").head()
        self.assertEqual(1, row.asDict()["l"][0].a)
        self.assertEqual(1.0, row.asDict()['d']['key'].c)
@@ -945,12 +931,12 @@ class SQLTests(ReusedPySparkTestCase):
        from pyspark.tests import ExamplePoint, ExamplePointUDT
        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
        rdd = self.sc.parallelize([row])
-        srdd = self.sqlCtx.inferSchema(rdd)
-        schema = srdd.schema()
+        df = self.sqlCtx.inferSchema(rdd)
+        schema = df.schema()
        field = [f for f in schema.fields if f.name == "point"][0]
        self.assertEqual(type(field.dataType), ExamplePointUDT)
-        srdd.registerTempTable("labeled_point")
-        point = self.sqlCtx.sql("SELECT point FROM labeled_point").first().point
+        df.registerTempTable("labeled_point")
+        point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
        self.assertEqual(point, ExamplePoint(1.0, 2.0))

    def test_apply_schema_with_udt(self):
@@ -959,21 +945,52 @@ class SQLTests(ReusedPySparkTestCase):
        rdd = self.sc.parallelize([row])
        schema = StructType([StructField("label", DoubleType(), False),
                             StructField("point", ExamplePointUDT(), False)])
-        srdd = self.sqlCtx.applySchema(rdd, schema)
-        point = srdd.first().point
+        df = self.sqlCtx.applySchema(rdd, schema)
+        point = df.head().point
        self.assertEquals(point, ExamplePoint(1.0, 2.0))

    def test_parquet_with_udt(self):
        from pyspark.tests import ExamplePoint
        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
        rdd = self.sc.parallelize([row])
-        srdd0 = self.sqlCtx.inferSchema(rdd)
+        df0 = self.sqlCtx.inferSchema(rdd)
        output_dir = os.path.join(self.tempdir.name, "labeled_point")
-        srdd0.saveAsParquetFile(output_dir)
-        srdd1 = self.sqlCtx.parquetFile(output_dir)
-        point = srdd1.first().point
+        df0.saveAsParquetFile(output_dir)
+        df1 = self.sqlCtx.parquetFile(output_dir)
+        point = df1.head().point
        self.assertEquals(point, ExamplePoint(1.0, 2.0))

+    def test_column_operators(self):
+        from pyspark.sql import Column, LongType
+        ci = self.df.key
+        cs = self.df.value
+        c = ci == cs
+        self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column))
+        rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci)
+        self.assertTrue(all(isinstance(c, Column) for c in rcc))
+        cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs]
+        self.assertTrue(all(isinstance(c, Column) for c in cb))
+        cbit = (ci & ci), (ci | ci), (ci ^ ci), (~ci)
+        self.assertTrue(all(isinstance(c, Column) for c in cbit))
+        css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a')
+        self.assertTrue(all(isinstance(c, Column) for c in css))
+        self.assertTrue(isinstance(ci.cast(LongType()), Column))
+
+    def test_column_select(self):
+        df = self.df
+        self.assertEqual(self.testData, df.select("*").collect())
+        self.assertEqual(self.testData, df.select(df.key, df.value).collect())
+        self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())
+
+    def test_aggregator(self):
+        df = self.df
+        g = df.groupBy()
+        self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
+        self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
+        # TODO(davies): fix aggregators
+        from pyspark.sql import Aggregator as Agg
+        # self.assertEqual((0, '100'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first()))
+

class InputFormatTests(ReusedPySparkTestCase):