From 0f90f4e6ac9e9ca694e3622b866f33d3fdf1a459 Mon Sep 17 00:00:00 2001 From: Franklyn D'souza Date: Sun, 21 Feb 2016 16:58:17 -0800 Subject: [SPARK-13410][SQL] Support unionAll for DataFrames with UDT columns. ## What changes were proposed in this pull request? This PR adds equality operators to UDT classes so that they can be correctly tested for dataType equality during union operations. This was previously causing `"AnalysisException: u"unresolved operator 'Union;""` when trying to unionAll two dataframes with UDT columns as below. ``` from pyspark.sql.tests import PythonOnlyPoint, PythonOnlyUDT from pyspark.sql import types schema = types.StructType([types.StructField("point", PythonOnlyUDT(), True)]) a = sqlCtx.createDataFrame([[PythonOnlyPoint(1.0, 2.0)]], schema) b = sqlCtx.createDataFrame([[PythonOnlyPoint(3.0, 4.0)]], schema) c = a.unionAll(b) ``` ## How was the this patch tested? Tested using two unit tests in sql/test.py and the DataFrameSuite. Additional information here : https://issues.apache.org/jira/browse/SPARK-13410 Author: Franklyn D'souza Closes #11279 from damnMeddlingKid/udt-union-all. --- python/pyspark/sql/tests.py | 18 ++++++++++++++++++ .../org/apache/spark/sql/types/UserDefinedType.scala | 10 ++++++++++ .../org/apache/spark/sql/test/ExamplePointUDT.scala | 7 ++++++- .../scala/org/apache/spark/sql/DataFrameSuite.scala | 16 ++++++++++++++++ 4 files changed, 50 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index e30aa0a796..cc11c0f35c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -601,6 +601,24 @@ class SQLTests(ReusedPySparkTestCase): point = df1.head().point self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) + def test_unionAll_with_udt(self): + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT + row1 = (1.0, ExamplePoint(1.0, 2.0)) + row2 = (2.0, ExamplePoint(3.0, 4.0)) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + df1 = self.sqlCtx.createDataFrame([row1], schema) + df2 = self.sqlCtx.createDataFrame([row2], schema) + + result = df1.unionAll(df2).orderBy("label").collect() + self.assertEqual( + result, + [ + Row(label=1.0, point=ExamplePoint(1.0, 2.0)), + Row(label=2.0, point=ExamplePoint(3.0, 4.0)) + ] + ) + def test_column_operators(self): ci = self.df.key cs = self.df.value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index d7a2c23be8..7664c30ee7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -86,6 +86,11 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { this.getClass == dataType.getClass override def sql: String = sqlType.sql + + override def equals(other: Any): Boolean = other match { + case that: UserDefinedType[_] => this.acceptsType(that) + case _ => false + } } /** @@ -112,4 +117,9 @@ private[sql] class PythonUserDefinedType( ("serializedClass" -> serializedPyClass) ~ ("sqlType" -> sqlType.jsonValue) } + + override def equals(other: Any): Boolean = other match { + case that: PythonUserDefinedType => this.pyUDT.equals(that.pyUDT) + case _ => false + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index 20a17ba82b..e2c9fc421b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -26,7 +26,12 @@ import org.apache.spark.sql.types._ * @param y y coordinate */ @SQLUserDefinedType(udt = classOf[ExamplePointUDT]) -private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable +private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable { + override def equals(other: Any): Boolean = other match { + case that: ExamplePoint => this.x == that.x && this.y == that.y + case _ => false + } +} /** * User-defined type for [[ExamplePoint]]. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 50a246489e..4930c485da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -112,6 +112,22 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { ) } + test("unionAll should union DataFrames with UDTs (SPARK-13410)") { + val rowRDD1 = sparkContext.parallelize(Seq(Row(1, new ExamplePoint(1.0, 2.0)))) + val schema1 = StructType(Array(StructField("label", IntegerType, false), + StructField("point", new ExamplePointUDT(), false))) + val rowRDD2 = sparkContext.parallelize(Seq(Row(2, new ExamplePoint(3.0, 4.0)))) + val schema2 = StructType(Array(StructField("label", IntegerType, false), + StructField("point", new ExamplePointUDT(), false))) + val df1 = sqlContext.createDataFrame(rowRDD1, schema1) + val df2 = sqlContext.createDataFrame(rowRDD2, schema2) + + checkAnswer( + df1.unionAll(df2).orderBy("label"), + Seq(Row(1, new ExamplePoint(1.0, 2.0)), Row(2, new ExamplePoint(3.0, 4.0))) + ) + } + test("empty data frame") { assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) assert(sqlContext.emptyDataFrame.count() === 0) -- cgit v1.2.3