about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 1
-rw-r--r-- sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala | 3
-rw-r--r-- sql/hive/src/test/resources/golden/Cast Timestamp to Timestamp in UDF-0-66952a3949d7544716fd1a675498b1fa | 1
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala | 11
4 files changed, 14 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 0ad2b30cf9..0379275121 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -245,6 +245,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
}
private[this] lazy val cast: Any => Any = dataType match {
+ case dt if dt == child.dataType => identity[Any]
case StringType => castToString
case BinaryType => castToBinary
case DecimalType => castToDecimal
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index 7d1ad53d8b..7cda0dd302 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -51,12 +51,13 @@ private[hive] abstract class HiveFunctionRegistry
val function = functionInfo.getFunctionClass.newInstance().asInstanceOf[UDF]
val method = function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo))
- lazy val expectedDataTypes = method.getParameterTypes.map(javaClassToDataType)
+ val expectedDataTypes = method.getParameterTypes.map(javaClassToDataType)
HiveSimpleUdf(
functionClassName,
children.zip(expectedDataTypes).map {
case (e, NullType) => e
+ case (e, t) if (e.dataType == t) => e
case (e, t) => Cast(e, t)
}
)
diff --git a/sql/hive/src/test/resources/golden/Cast Timestamp to Timestamp in UDF-0-66952a3949d7544716fd1a675498b1fa b/sql/hive/src/test/resources/golden/Cast Timestamp to Timestamp in UDF-0-66952a3949d7544716fd1a675498b1fa
new file mode 100644
index 0000000000..7951defec1
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/Cast Timestamp to Timestamp in UDF-0-66952a3949d7544716fd1a675498b1fa
@@ -0,0 +1 @@
+NULL
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 8c8a8b124a..56bcd95eab 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -142,16 +142,25 @@ class HiveQuerySuite extends HiveComparisonTest {
setConf("spark.sql.dialect", "sql")
assert(sql("SELECT 1").collect() === Array(Seq(1)))
setConf("spark.sql.dialect", "hiveql")
-
}
test("Query expressed in HiveQL") {
sql("FROM src SELECT key").collect()
}
+ test("Query with constant folding the CAST") {
+ sql("SELECT CAST(CAST('123' AS binary) AS binary) FROM src LIMIT 1").collect()
+ }
+
createQueryTest("Constant Folding Optimization for AVG_SUM_COUNT",
"SELECT AVG(0), SUM(0), COUNT(null), COUNT(value) FROM src GROUP BY key")
+ createQueryTest("Cast Timestamp to Timestamp in UDF",
+ """
+ | SELECT DATEDIFF(CAST(value AS timestamp), CAST('2002-03-21 00:00:00' AS timestamp))
+ | FROM src LIMIT 1
+ """.stripMargin)
+
createQueryTest("Simple Average",
"SELECT AVG(key) FROM src")