Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r--    python/pyspark/sql/tests.py    68
1 file changed, 65 insertions(+), 3 deletions(-)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 39071e7e35..83899ad4b1 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -36,9 +36,9 @@ if sys.version_info[:2] <= (2, 6):
else:
    import unittest
-from pyspark.sql import SQLContext, HiveContext, Column
-from pyspark.sql.types import IntegerType, Row, ArrayType, StructType, StructField, \
-    UserDefinedType, DoubleType, LongType, StringType, _infer_type
+from pyspark.sql import SQLContext, HiveContext, Column, Row
+from pyspark.sql.types import *
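+# UserDefinedType and _infer_type are not covered by the star import, so bring them in explicitly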
+from pyspark.sql.types import UserDefinedType, _infer_type
from pyspark.tests import ReusedPySparkTestCase
@@ -204,6 +204,68 @@ class SQLTests(ReusedPySparkTestCase):
        result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
        self.assertEqual(1, result.head()[0])
+    def test_infer_nested_schema(self):
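+        # lists and dicts inside a Row should be inferred as array and map columns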
+        NestedRow = Row("f1", "f2")
+        nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}),
+                                          NestedRow([2, 3], {"row2": 2.0})])
+        df = self.sqlCtx.inferSchema(nestedRdd1)
+        self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0])
+
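+        # nested lists (arrays of arrays) should round-trip through inference as well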
+        nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]),
+                                          NestedRow([[2, 3], [3, 4]], [2, 3])])
+        df = self.sqlCtx.inferSchema(nestedRdd2)
+        self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0])
+
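+        # namedtuple instances are treated like Rows: field names become column names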
+        from collections import namedtuple
+        CustomRow = namedtuple('CustomRow', 'field1 field2')
+        rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"),
+                                   CustomRow(field1=2, field2="row2"),
+                                   CustomRow(field1=3, field2="row3")])
+        df = self.sqlCtx.inferSchema(rdd)
+        self.assertEquals(Row(field1=1, field2=u'row1'), df.first())
+
+    def test_apply_schema(self):
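+        # applySchema matches each tuple element to the StructField at the same position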
+        from datetime import date, datetime
+        rdd = self.sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0,
+                                    date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1),
+                                    {"a": 1}, (2,), [1, 2, 3], None)])
+        schema = StructType([
+            StructField("byte1", ByteType(), False),
+            StructField("byte2", ByteType(), False),
+            StructField("short1", ShortType(), False),
+            StructField("short2", ShortType(), False),
+            StructField("int1", IntegerType(), False),
+            StructField("float1", FloatType(), False),
+            StructField("date1", DateType(), False),
+            StructField("time1", TimestampType(), False),
+            StructField("map1", MapType(StringType(), IntegerType(), False), False),
+            StructField("struct1", StructType([StructField("b", ShortType(), False)]), False),
+            StructField("list1", ArrayType(ByteType(), False), False),
+            StructField("null1", DoubleType(), True)])
+        df = self.sqlCtx.applySchema(rdd, schema)
+        results = df.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1,
+                                    x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
+        r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1),
+             datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
+        self.assertEqual(r, results.first())
+
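+        # register the typed DataFrame as a temp table and query it with SQL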
+        df.registerTempTable("table2")
+        r = self.sqlCtx.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
+                            "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " +
+                            "float1 + 1.5 as float1 FROM table2").first()
+
+        self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r))
+
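+        # a schema "abstract" gives only field names and nesting; _infer_schema_type fills in the types from a sample row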
+        from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type
+        rdd = self.sc.parallelize([(127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1),
+                                    {"a": 1}, (2,), [1, 2, 3])])
+        abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]"
+        schema = _parse_schema_abstract(abstract)
+        typedSchema = _infer_schema_type(rdd.first(), schema)
+        df = self.sqlCtx.applySchema(rdd, typedSchema)
+        r = (127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), {"a": 1}, Row(b=2), [1, 2, 3])
+        self.assertEqual(r, tuple(df.first()))
+
    def test_struct_in_map(self):
        d = [Row(m={Row(i=1): Row(s="")})]
        df = self.sc.parallelize(d).toDF()