about summary refs log tree commit diff
path: root/python/pyspark/sql/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r-- python/pyspark/sql/tests.py 16
1 files changed, 16 insertions, 0 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 87dbb50495..520b09d9c6 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -411,6 +411,22 @@ class SQLTests(ReusedPySparkTestCase):
df3 = self.spark.createDataFrame(rdd, df.schema)
self.assertEqual(10, df3.count())
+ def test_apply_schema_to_dict_and_rows(self):
+ schema = StructType().add("b", StringType()).add("a", IntegerType())
+ input = [{"a": 1}, {"b": "coffee"}]
+ rdd = self.sc.parallelize(input)
+ for verify in [False, True]:
+ df = self.spark.createDataFrame(input, schema, verifySchema=verify)
+ df2 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
+ self.assertEqual(df.schema, df2.schema)
+
+ rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
+ df3 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
+ self.assertEqual(10, df3.count())
+ input = [Row(a=x, b=str(x)) for x in range(10)]
+ df4 = self.spark.createDataFrame(input, schema, verifySchema=verify)
+ self.assertEqual(10, df4.count())
+
def test_create_dataframe_schema_mismatch(self):
input = [Row(a=1)]
rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))