aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <viirya@gmail.com>2017-02-04 15:57:56 -0800
committergatorsmile <gatorsmile@gmail.com>2017-02-04 15:57:56 -0800
commit0674e7eb85160e3f8da333b5243d76063824d58c (patch)
tree6ce795f0228a96b87364009644e25c759064e0db
parent2f3c20bbddd266015d9478c35ce2b37d67e01200 (diff)
downloadspark-0674e7eb85160e3f8da333b5243d76063824d58c.tar.gz
spark-0674e7eb85160e3f8da333b5243d76063824d58c.tar.bz2
spark-0674e7eb85160e3f8da333b5243d76063824d58c.zip
[SPARK-19425][SQL] Make ExtractEquiJoinKeys support UDT columns
## What changes were proposed in this pull request? DataFrame.except doesn't work for UDT columns. It is because `ExtractEquiJoinKeys` will run `Literal.default` against UDT. However, we don't handle UDT in `Literal.default` and an exception will throw like: java.lang.RuntimeException: no default for type org.apache.spark.ml.linalg.VectorUDT3bfc3ba7 at org.apache.spark.sql.catalyst.expressions.Literal$.default(literals.scala:179) at org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys$$anonfun$4.apply(patterns.scala:117) at org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys$$anonfun$4.apply(patterns.scala:110) More simple fix is just let `Literal.default` handle UDT by its sql type. So we can use more efficient join type on UDT. Besides `except`, this also fixes other similar scenarios, so in summary this fixes: * `except` on two Datasets with UDT * `intersect` on two Datasets with UDT * `Join` with the join conditions using `<=>` on UDT columns ## How was this patch tested? Jenkins tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #16765 from viirya/df-except-for-udt.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala1
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala3
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala9
3 files changed, 13 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index cb0c4d333b..e66fb89339 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -175,6 +175,7 @@ object Literal {
case map: MapType => create(Map(), map)
case struct: StructType =>
create(InternalRow.fromSeq(struct.fields.map(f => default(f.dataType).value)), struct)
+ case udt: UserDefinedType[_] => default(udt.sqlType)
case other =>
throw new RuntimeException(s"no default for type $dataType")
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
index 4af4da8a9f..15e8e6c057 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
@@ -22,6 +22,7 @@ import java.nio.charset.StandardCharsets
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
@@ -67,6 +68,8 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Literal.default(ArrayType(StringType)), Array())
checkEvaluation(Literal.default(MapType(IntegerType, StringType)), Map())
checkEvaluation(Literal.default(StructType(StructField("a", StringType) :: Nil)), Row(""))
+ // ExamplePointUDT.sqlType is ArrayType(DoubleType, false).
+ checkEvaluation(Literal.default(new ExamplePointUDT), Array())
}
test("boolean literals") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index ea4a8ee7ff..c7a77daaca 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -150,6 +150,10 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
MyLabeledPoint(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))),
MyLabeledPoint(0.0, new UDT.MyDenseVector(Array(0.2, 2.0)))).toDF()
+ private lazy val pointsRDD2 = Seq(
+ MyLabeledPoint(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))),
+ MyLabeledPoint(0.0, new UDT.MyDenseVector(Array(0.3, 3.0)))).toDF()
+
test("register user type: MyDenseVector for MyLabeledPoint") {
val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v }
val labelsArrays: Array[Double] = labels.collect()
@@ -297,4 +301,9 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
sql("SELECT doOtherUDF(doSubTypeUDF(42))")
}
+ test("except on UDT") {
+ checkAnswer(
+ pointsRDD.except(pointsRDD2),
+ Seq(Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0)))))
+ }
}