aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r--python/pyspark/tests.py19
1 files changed, 19 insertions, 0 deletions
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 253a471849..68fd756876 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -796,6 +796,25 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual(1.0, row.c)
self.assertEqual("2", row.d)
+ def test_infer_schema(self):
+ d = [Row(l=[], d={}),
+ Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
+ rdd = self.sc.parallelize(d)
+ srdd = self.sqlCtx.inferSchema(rdd)
+ self.assertEqual([], srdd.map(lambda r: r.l).first())
+ self.assertEqual([None, ""], srdd.map(lambda r: r.s).collect())
+ srdd.registerTempTable("test")
+ result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'")
+ self.assertEqual(1, result.first()[0])
+
+ srdd2 = self.sqlCtx.inferSchema(rdd, 1.0)
+ self.assertEqual(srdd.schema(), srdd2.schema())
+ self.assertEqual({}, srdd2.map(lambda r: r.d).first())
+ self.assertEqual([None, ""], srdd2.map(lambda r: r.s).collect())
+ srdd2.registerTempTable("test2")
+ result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
+ self.assertEqual(1, result.first()[0])
+
def test_convert_row_to_dict(self):
row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
self.assertEqual(1, row.asDict()['l'][0].a)