Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r--  python/pyspark/sql/tests.py | 40 +++++++++++++++++++++++++++++++++++++---
1 file changed, 37 insertions(+), 3 deletions(-)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index b5b848e1db..9722e9e9ca 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -369,9 +369,7 @@ class SQLTests(ReusedPySparkTestCase):
         rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))
         schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())])
         df = self.sqlCtx.createDataFrame(rdd, schema)
-        message = ".*Input row doesn't have expected number of values required by the schema.*"
-        with self.assertRaisesRegexp(Exception, message):
-            df.show()
+        self.assertRaises(Exception, lambda: df.show())
 
     def test_serialize_nested_array_and_map(self):
         d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
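Note on the hunk above: the simplified assertion still pins down the important behavior. createDataFrame does not validate rows against the schema eagerly, so the mismatch only surfaces when an action such as show() runs. A minimal standalone sketch of that behavior follows; the local SparkContext/SQLContext setup and the app name are illustrative assumptions, not part of this test file.

from pyspark import SparkContext
from pyspark.sql import Row, SQLContext
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

sc = SparkContext("local[2]", "schema-mismatch-sketch")  # illustrative local setup
sqlCtx = SQLContext(sc)

# Each Row carries one field, but the schema declares two.
rdd = sc.parallelize(range(3)).map(lambda i: Row(a=i))
schema = StructType([StructField("a", IntegerType()),
                     StructField("b", StringType())])

# Constructing the DataFrame succeeds because no data is touched yet...
df = sqlCtx.createDataFrame(rdd, schema)

try:
    df.show()  # ...the mismatch is only detected once an action runs.
except Exception as e:
    print("show() failed as expected: %s" % e)

sc.stop()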
@@ -1178,6 +1176,42 @@ class SQLTests(ReusedPySparkTestCase):
         # planner should not crash without a join
         broadcast(df1)._jdf.queryExecution().executedPlan()
 
+    def test_toDF_with_schema_string(self):
+        data = [Row(key=i, value=str(i)) for i in range(100)]
+        rdd = self.sc.parallelize(data, 5)
+
+        df = rdd.toDF("key: int, value: string")
+        self.assertEqual(df.schema.simpleString(), "struct<key:int,value:string>")
+        self.assertEqual(df.collect(), data)
+
+        # different but compatible field types can be used.
+        df = rdd.toDF("key: string, value: string")
+        self.assertEqual(df.schema.simpleString(), "struct<key:string,value:string>")
+        self.assertEqual(df.collect(), [Row(key=str(i), value=str(i)) for i in range(100)])
+
+        # field names can differ.
+        df = rdd.toDF(" a: int, b: string ")
+        self.assertEqual(df.schema.simpleString(), "struct<a:int,b:string>")
+        self.assertEqual(df.collect(), data)
+
+        # number of fields must match.
+        self.assertRaisesRegexp(Exception, "Length of object",
+                                lambda: rdd.toDF("key: int").collect())
+
+        # field types mismatch will cause exception at runtime.
+        self.assertRaisesRegexp(Exception, "FloatType can not accept",
+                                lambda: rdd.toDF("key: float, value: string").collect())
+
+        # flat schema values will be wrapped into row.
+        df = rdd.map(lambda row: row.key).toDF("int")
+        self.assertEqual(df.schema.simpleString(), "struct<value:int>")
+        self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
+
+        # users can use DataType directly instead of data type string.
+        df = rdd.map(lambda row: row.key).toDF(IntegerType())
+        self.assertEqual(df.schema.simpleString(), "struct<value:int>")
+        self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
+
 
 class HiveContextSQLTests(ReusedPySparkTestCase):
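Usage note: assuming a build that includes this change, the behaviors the new test asserts can be tried end to end with a sketch like the one below. The local-mode setup and app name are illustrative assumptions; instantiating SQLContext is what makes toDF() available on RDDs, and the expected schemas in the comments come straight from the test's assertions.

from pyspark import SparkContext
from pyspark.sql import Row, SQLContext
from pyspark.sql.types import IntegerType

sc = SparkContext("local[2]", "toDF-schema-string-sketch")  # illustrative local setup
sqlCtx = SQLContext(sc)  # creating it attaches toDF() to RDDs

rdd = sc.parallelize([Row(key=i, value=str(i)) for i in range(5)])

# Schema given as a "name: type" string; field names and types come from it.
df = rdd.toDF("key: int, value: string")
print(df.schema.simpleString())   # struct<key:int,value:string>

# A bare type string wraps flat values into a single column named "value".
df2 = rdd.map(lambda row: row.key).toDF("int")
print(df2.schema.simpleString())  # struct<value:int>

# A DataType instance works the same way as the bare type string.
df3 = rdd.map(lambda row: row.key).toDF(IntegerType())
print(df3.schema.simpleString())  # struct<value:int>

sc.stop()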