diff options
author | Franklyn D'souza <franklynd@gmail.com> | 2016-02-21 16:58:17 -0800 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-02-21 16:58:17 -0800 |
commit | 0f90f4e6ac9e9ca694e3622b866f33d3fdf1a459 (patch) | |
tree | 776597399768afa7ce49e2726f156d4e0c125c1c /python | |
parent | 0cbadf28c99721ba1ac22ac57beef9df998ea685 (diff) | |
download | spark-0f90f4e6ac9e9ca694e3622b866f33d3fdf1a459.tar.gz spark-0f90f4e6ac9e9ca694e3622b866f33d3fdf1a459.tar.bz2 spark-0f90f4e6ac9e9ca694e3622b866f33d3fdf1a459.zip |
[SPARK-13410][SQL] Support unionAll for DataFrames with UDT columns.
## What changes were proposed in this pull request?
This PR adds equality operators to UDT classes so that they can be correctly tested for dataType equality during union operations.
This was previously causing `"AnalysisException: u"unresolved operator 'Union;""` when trying to unionAll two dataframes with UDT columns as below.
```
from pyspark.sql.tests import PythonOnlyPoint, PythonOnlyUDT
from pyspark.sql import types
schema = types.StructType([types.StructField("point", PythonOnlyUDT(), True)])
a = sqlCtx.createDataFrame([[PythonOnlyPoint(1.0, 2.0)]], schema)
b = sqlCtx.createDataFrame([[PythonOnlyPoint(3.0, 4.0)]], schema)
c = a.unionAll(b)
```
## How was the this patch tested?
Tested using two unit tests in sql/test.py and the DataFrameSuite.
Additional information here : https://issues.apache.org/jira/browse/SPARK-13410
Author: Franklyn D'souza <franklynd@gmail.com>
Closes #11279 from damnMeddlingKid/udt-union-all.
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/sql/tests.py | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index e30aa0a796..cc11c0f35c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -601,6 +601,24 @@ class SQLTests(ReusedPySparkTestCase): point = df1.head().point self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) + def test_unionAll_with_udt(self): + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT + row1 = (1.0, ExamplePoint(1.0, 2.0)) + row2 = (2.0, ExamplePoint(3.0, 4.0)) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + df1 = self.sqlCtx.createDataFrame([row1], schema) + df2 = self.sqlCtx.createDataFrame([row2], schema) + + result = df1.unionAll(df2).orderBy("label").collect() + self.assertEqual( + result, + [ + Row(label=1.0, point=ExamplePoint(1.0, 2.0)), + Row(label=2.0, point=ExamplePoint(3.0, 4.0)) + ] + ) + def test_column_operators(self): ci = self.df.key cs = self.df.value |