Diffstat (limited to 'sql')
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala | 2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala | 14
-rw-r--r--  sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55 | 1
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala | 8
4 files changed, 17 insertions(+), 8 deletions(-)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index ada980acb1..0eeac8620f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -326,6 +326,8 @@ private[hive] trait HiveInspectors {
         })
         ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, map)
       }
+    case Literal(_, dt) => sys.error(s"Hive doesn't support the constant type [$dt].")
+    case _ if expr.foldable => toInspector(Literal(expr.eval(), expr.dataType))
     case _ => toInspector(expr.dataType)
   }
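
The two added cases in toInspector follow a fold-then-redispatch pattern: a foldable but non-literal expression is evaluated once, wrapped in a Literal of its data type, and passed back into toInspector so the existing constant-inspector cases can handle it, while a Literal whose type has no constant-inspector case now fails with an explicit error instead of silently getting a non-constant inspector. A minimal standalone sketch of that pattern (the Expr/Lit/Lower hierarchy below is invented for illustration and is not Catalyst's Expression API):

object FoldThenRedispatch {
  // Toy expression tree standing in for Catalyst expressions (illustrative only).
  sealed trait Expr { def foldable: Boolean; def eval(): Any }
  case class Lit(value: Any) extends Expr { val foldable = true; def eval(): Any = value }
  case class Lower(child: Expr) extends Expr {
    def foldable: Boolean = child.foldable   // deterministic, so foldable when its child is
    def eval(): Any = child.eval().toString.toLowerCase
  }

  // Stands in for toInspector: literals get a "constant inspector"; other foldable
  // expressions are evaluated once and re-dispatched as literals.
  def describe(expr: Expr): String = expr match {
    case Lit(v)          => s"constant inspector for [$v]"
    case e if e.foldable => describe(Lit(e.eval()))
    case _               => "non-constant inspector"
  }

  def main(args: Array[String]): Unit = {
    println(describe(Lower(Lit("AA"))))   // prints: constant inspector for [aa]
  }
}
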
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index 86f7eea5df..b255a2ebb9 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -21,7 +21,7 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
 import scala.collection.mutable.ArrayBuffer
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
+import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ConstantObjectInspector}
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory
 import org.apache.hadoop.hive.ql.exec.{UDF, UDAF}
@@ -108,9 +108,7 @@ private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[
     udfType != null && udfType.deterministic()
   }
-  override def foldable = {
-    isUDFDeterministic && children.foldLeft(true)((prev, n) => prev && n.foldable)
-  }
+  override def foldable = isUDFDeterministic && children.forall(_.foldable)
   // Create parameter converters
   @transient
@@ -154,7 +152,8 @@ private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq
   protected lazy val argumentInspectors = children.map(toInspector)
   @transient
-  protected lazy val returnInspector = function.initialize(argumentInspectors.toArray)
+  protected lazy val returnInspector =
+    function.initializeAndFoldConstants(argumentInspectors.toArray)
   @transient
   protected lazy val isUDFDeterministic = {
@@ -162,9 +161,8 @@ private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq
     (udfType != null && udfType.deterministic())
   }
-  override def foldable = {
-    isUDFDeterministic && children.foldLeft(true)((prev, n) => prev && n.foldable)
-  }
+  override def foldable =
+    isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector]
   @transient
   protected lazy val deferedObjects =
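
For HiveSimpleUdf the foldable check is only simplified: folding the children's flags through && is the same as requiring that every child is foldable, and forall also short-circuits at the first non-foldable child. A small standalone check of that equivalence (Expression here is a bare stand-in, not Catalyst's trait):

object FoldableEquivalence {
  // Minimal stand-in exposing only the foldable flag the check looks at.
  case class Expression(foldable: Boolean)

  def oldStyle(children: Seq[Expression]): Boolean =
    children.foldLeft(true)((prev, n) => prev && n.foldable)

  def newStyle(children: Seq[Expression]): Boolean =
    children.forall(_.foldable)

  def main(args: Array[String]): Unit = {
    val inputs = Seq(
      Seq.empty[Expression],                        // no children: both are vacuously true
      Seq(Expression(true), Expression(true)),
      Seq(Expression(true), Expression(false)))
    inputs.foreach(c => assert(oldStyle(c) == newStyle(c)))
    println("foldLeft and forall formulations agree")
  }
}

For HiveGenericUdf the change is behavioral rather than cosmetic: foldability is now delegated to Hive, since the UDF is treated as constant exactly when initializeAndFoldConstants handed back a ConstantObjectInspector, i.e. when Hive itself could evaluate the call up front from deterministic semantics and all-constant argument inspectors.
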
diff --git a/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55 b/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55
new file mode 100644
index 0000000000..7bc77e7f2a
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55
@@ -0,0 +1 @@
+{"aa":"10","aaaaaa":"11","aaaaaa":"12","bb12":"13","s14s14":"14"}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 0dd766f253..af45dfd6e2 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -56,6 +56,14 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
     Locale.setDefault(originalLocale)
   }
+  createQueryTest("constant object inspector for generic udf",
+    """SELECT named_struct(
+      lower("AA"), "10",
+      repeat(lower("AA"), 3), "11",
+      lower(repeat("AA", 3)), "12",
+      printf("Bb%d", 12), "13",
+      repeat(printf("s%d", 14), 2), "14") FROM src LIMIT 1""")
+
   createQueryTest("NaN to Decimal",
     "SELECT CAST(CAST('NaN' AS DOUBLE) AS DECIMAL(1,1)) FROM src LIMIT 1")