path: root/python/pyspark/sql/tests.py
author Davies Liu <davies@databricks.com> 2016-08-15 12:41:27 -0700
committer Josh Rosen <joshrosen@databricks.com> 2016-08-15 12:41:27 -0700
commit fffb0c0d19a2444e7554dfe6b27de0c086112b17 (patch)
tree dc2fc14a9820672633b61b6acdf4a3d76985caf1 /python/pyspark/sql/tests.py
parent 5da6c4b24f512b63cd4e6ba7dd8968066a9396f5 (diff)
[SPARK-16700][PYSPARK][SQL] create DataFrame from dict/Row with schema
## What changes were proposed in this pull request?

In 2.0 we verify the data type against the schema for every row, which is safe but has a performance cost; this PR makes that verification optional. Also, when verifying the data type for StructType, we did not support all the types that schema inference supports (for example, dict); this PR fixes that so the two are consistent. Finally, for a Row object created with named arguments, the fields are sorted by name, so their order may differ from the order in the provided schema; this PR fixes that by ignoring the field order in this case.

## How was this patch tested?

Added regression tests for these cases.

Author: Davies Liu <davies@databricks.com>

Closes #14469 from davies/py_dict.
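Below is a minimal, illustrative sketch (not part of this commit) of the behavior the patch enables: creating a DataFrame from dicts and from Rows whose field order differs from the schema, with per-row verification optionally disabled via `verifySchema`. The local SparkSession setup and app name are assumptions made for the example.

```python
# Illustrative sketch only; assumes a local SparkSession (not from this commit).
from pyspark.sql import Row, SparkSession
from pyspark.sql.types import IntegerType, StringType, StructType

spark = (SparkSession.builder
         .master("local[1]")
         .appName("spark-16700-demo")
         .getOrCreate())

# Schema deliberately lists "b" before "a".
schema = StructType().add("b", StringType()).add("a", IntegerType())

# dicts are now verified consistently with schema inference;
# missing keys become nulls.
df = spark.createDataFrame([{"a": 1}, {"b": "coffee"}], schema)

# Row(a=..., b=...) sorts its fields by name, so the field order need not
# match the schema; the patch matches such Rows by field name instead.
rows = [Row(a=x, b=str(x)) for x in range(10)]

# verifySchema=False skips the per-row type check for speed.
df2 = spark.createDataFrame(rows, schema, verifySchema=False)

spark.stop()
```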
Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r-- python/pyspark/sql/tests.py | 16
1 file changed, 16 insertions(+), 0 deletions(-)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 87dbb50495..520b09d9c6 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -411,6 +411,22 @@ class SQLTests(ReusedPySparkTestCase):
         df3 = self.spark.createDataFrame(rdd, df.schema)
         self.assertEqual(10, df3.count())
 
+    def test_apply_schema_to_dict_and_rows(self):
+        schema = StructType().add("b", StringType()).add("a", IntegerType())
+        input = [{"a": 1}, {"b": "coffee"}]
+        rdd = self.sc.parallelize(input)
+        for verify in [False, True]:
+            df = self.spark.createDataFrame(input, schema, verifySchema=verify)
+            df2 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
+            self.assertEqual(df.schema, df2.schema)
+
+            rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
+            df3 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
+            self.assertEqual(10, df3.count())
+            input = [Row(a=x, b=str(x)) for x in range(10)]
+            df4 = self.spark.createDataFrame(input, schema, verifySchema=verify)
+            self.assertEqual(10, df4.count())
+
     def test_create_dataframe_schema_mismatch(self):
         input = [Row(a=1)]
         rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))