author     Wenchen Fan <wenchen@databricks.com>    2016-03-08 14:00:03 -0800
committer  Davies Liu <davies.liu@gmail.com>       2016-03-08 14:00:03 -0800
commit     d57daf1f7732a7ac54a91fe112deeda0a254f9ef (patch)
tree       ab88e42e961763cadd75d95b9c9989b3c460bde6 /python/pyspark/sql/tests.py
parent     d5ce61722f2aa4167a8344e1664b000c70d5a3f8 (diff)
[SPARK-13593] [SQL] Improve `createDataFrame` to accept a data type string and verify the data
## What changes were proposed in this pull request?

This PR improves the `createDataFrame` method to also accept a data type string as the schema, so users can easily convert a Python RDD to a DataFrame, for example `df = rdd.toDF("a: int, b: string")`. It also supports flat schemas, so users can convert an RDD of ints to a DataFrame directly; we automatically wrap each int into a row for them. If a schema is given, we now check that the actual data matches it and throw an error if it does not.

## How was this patch tested?

New tests in `tests.py` and a doctest in `types.py`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #11444 from cloud-fan/pyrdd.
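For context, a minimal sketch of the new API described above. It assumes a PySpark build that includes this patch; the app name and the SparkContext/SQLContext setup are illustrative placeholders, not part of the change:

```python
from pyspark import SparkContext
from pyspark.sql import Row, SQLContext
from pyspark.sql.types import IntegerType

sc = SparkContext(appName="toDF-demo")  # hypothetical app name
sqlContext = SQLContext(sc)             # instantiating SQLContext attaches toDF() to RDDs

# Schema given as a data type string of "name: type" pairs.
rdd = sc.parallelize([Row(a=i, b=str(i)) for i in range(5)])
df = rdd.toDF("a: int, b: string")
df.show()

# Flat schema: plain values are wrapped into rows automatically.
sc.parallelize(range(5)).toDF("int").show()

# A DataType instance also works in place of a data type string.
sc.parallelize(range(5)).toDF(IntegerType()).show()

# The data is verified against the given schema; a mismatch raises at runtime.
try:
    rdd.toDF("a: float, b: string").collect()
except Exception as e:
    print("schema verification failed as expected: %s" % e)
```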
Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r--  python/pyspark/sql/tests.py  40
1 file changed, 37 insertions(+), 3 deletions(-)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index b5b848e1db..9722e9e9ca 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -369,9 +369,7 @@ class SQLTests(ReusedPySparkTestCase):
         rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))
         schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())])
         df = self.sqlCtx.createDataFrame(rdd, schema)
-        message = ".*Input row doesn't have expected number of values required by the schema.*"
-        with self.assertRaisesRegexp(Exception, message):
-            df.show()
+        self.assertRaises(Exception, lambda: df.show())
 
     def test_serialize_nested_array_and_map(self):
         d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
@@ -1178,6 +1176,42 @@ class SQLTests(ReusedPySparkTestCase):
         # planner should not crash without a join
         broadcast(df1)._jdf.queryExecution().executedPlan()
 
+    def test_toDF_with_schema_string(self):
+        data = [Row(key=i, value=str(i)) for i in range(100)]
+        rdd = self.sc.parallelize(data, 5)
+
+        df = rdd.toDF("key: int, value: string")
+        self.assertEqual(df.schema.simpleString(), "struct<key:int,value:string>")
+        self.assertEqual(df.collect(), data)
+
+        # different but compatible field types can be used.
+        df = rdd.toDF("key: string, value: string")
+        self.assertEqual(df.schema.simpleString(), "struct<key:string,value:string>")
+        self.assertEqual(df.collect(), [Row(key=str(i), value=str(i)) for i in range(100)])
+
+        # field names can differ.
+        df = rdd.toDF(" a: int, b: string ")
+        self.assertEqual(df.schema.simpleString(), "struct<a:int,b:string>")
+        self.assertEqual(df.collect(), data)
+
+        # number of fields must match.
+        self.assertRaisesRegexp(Exception, "Length of object",
+                                lambda: rdd.toDF("key: int").collect())
+
+        # field types mismatch will cause exception at runtime.
+        self.assertRaisesRegexp(Exception, "FloatType can not accept",
+                                lambda: rdd.toDF("key: float, value: string").collect())
+
+        # flat schema values will be wrapped into row.
+        df = rdd.map(lambda row: row.key).toDF("int")
+        self.assertEqual(df.schema.simpleString(), "struct<value:int>")
+        self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
+
+        # users can use DataType directly instead of data type string.
+        df = rdd.map(lambda row: row.key).toDF(IntegerType())
+        self.assertEqual(df.schema.simpleString(), "struct<value:int>")
+        self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
+
 
 class HiveContextSQLTests(ReusedPySparkTestCase):