aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala1
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala3
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala9
3 files changed, 13 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index cb0c4d333b..e66fb89339 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -175,6 +175,7 @@ object Literal {
case map: MapType => create(Map(), map)
case struct: StructType =>
create(InternalRow.fromSeq(struct.fields.map(f => default(f.dataType).value)), struct)
+ case udt: UserDefinedType[_] => default(udt.sqlType)
case other =>
throw new RuntimeException(s"no default for type $dataType")
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
index 4af4da8a9f..15e8e6c057 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
@@ -22,6 +22,7 @@ import java.nio.charset.StandardCharsets
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
@@ -67,6 +68,8 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Literal.default(ArrayType(StringType)), Array())
checkEvaluation(Literal.default(MapType(IntegerType, StringType)), Map())
checkEvaluation(Literal.default(StructType(StructField("a", StringType) :: Nil)), Row(""))
+ // ExamplePointUDT.sqlType is ArrayType(DoubleType, false).
+ checkEvaluation(Literal.default(new ExamplePointUDT), Array())
}
test("boolean literals") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index ea4a8ee7ff..c7a77daaca 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -150,6 +150,10 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
MyLabeledPoint(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))),
MyLabeledPoint(0.0, new UDT.MyDenseVector(Array(0.2, 2.0)))).toDF()
+ private lazy val pointsRDD2 = Seq(
+ MyLabeledPoint(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))),
+ MyLabeledPoint(0.0, new UDT.MyDenseVector(Array(0.3, 3.0)))).toDF()
+
test("register user type: MyDenseVector for MyLabeledPoint") {
val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v }
val labelsArrays: Array[Double] = labels.collect()
@@ -297,4 +301,9 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
sql("SELECT doOtherUDF(doSubTypeUDF(42))")
}
+ test("except on UDT") {
+ checkAnswer(
+ pointsRDD.except(pointsRDD2),
+ Seq(Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0)))))
+ }
}