From 0f6a2eeaf20363061f9ed2d9062f3a7022c2c8ba Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Thu, 20 Nov 2014 16:50:59 -0800 Subject: [SPARK-4244] [SQL] Support Hive Generic UDFs with constant object inspector parameters The query `SELECT named_struct(lower("AA"), "12", lower("Bb"), "13") FROM src LIMIT 1` throws an exception. Some Hive Generic UDFs/UDAFs require their input object inspectors to be `ConstantObjectInspector`s; however, we cannot provide that until expression optimization (constant folding) has been executed. This PR is a workaround for that. (Ideally, the `output` of a LogicalPlan should be identical before and after optimization.) Author: Cheng Hao Closes #3109 from chenghao-intel/optimized and squashes the following commits: 487ff79 [Cheng Hao] rebase to the latest master & update the unittest (cherry picked from commit 84d79ee9ec47465269f7b0a7971176da93c96f3f) Signed-off-by: Michael Armbrust --- .../scala/org/apache/spark/sql/hive/HiveInspectors.scala | 2 ++ .../main/scala/org/apache/spark/sql/hive/hiveUdfs.scala | 14 ++++++-------- ...ctor for generic udf-0-cc120a2331158f570a073599985d3f55 | 1 + .../apache/spark/sql/hive/execution/HiveQuerySuite.scala | 8 ++++++++ 4 files changed, 17 insertions(+), 8 deletions(-) create mode 100644 sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55 (limited to 'sql') diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index ada980acb1..0eeac8620f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -326,6 +326,8 @@ private[hive] trait HiveInspectors { }) ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, map) } + case Literal(_, dt) => sys.error(s"Hive doesn't support the constant type [$dt].") + case _ if expr.foldable 
=> toInspector(Literal(expr.eval(), expr.dataType)) case _ => toInspector(expr.dataType) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 86f7eea5df..b255a2ebb9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -21,7 +21,7 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper import scala.collection.mutable.ArrayBuffer -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ConstantObjectInspector} import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory import org.apache.hadoop.hive.ql.exec.{UDF, UDAF} @@ -108,9 +108,7 @@ private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[ udfType != null && udfType.deterministic() } - override def foldable = { - isUDFDeterministic && children.foldLeft(true)((prev, n) => prev && n.foldable) - } + override def foldable = isUDFDeterministic && children.forall(_.foldable) // Create parameter converters @transient @@ -154,7 +152,8 @@ private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq protected lazy val argumentInspectors = children.map(toInspector) @transient - protected lazy val returnInspector = function.initialize(argumentInspectors.toArray) + protected lazy val returnInspector = + function.initializeAndFoldConstants(argumentInspectors.toArray) @transient protected lazy val isUDFDeterministic = { @@ -162,9 +161,8 @@ private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq (udfType != null && udfType.deterministic()) } - override def foldable = { - isUDFDeterministic && children.foldLeft(true)((prev, n) => prev && 
n.foldable) - } + override def foldable = + isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector] @transient protected lazy val deferedObjects = diff --git a/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55 b/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55 new file mode 100644 index 0000000000..7bc77e7f2a --- /dev/null +++ b/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55 @@ -0,0 +1 @@ +{"aa":"10","aaaaaa":"11","aaaaaa":"12","bb12":"13","s14s14":"14"} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 0dd766f253..af45dfd6e2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -56,6 +56,14 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { Locale.setDefault(originalLocale) } + createQueryTest("constant object inspector for generic udf", + """SELECT named_struct( + lower("AA"), "10", + repeat(lower("AA"), 3), "11", + lower(repeat("AA", 3)), "12", + printf("Bb%d", 12), "13", + repeat(printf("s%d", 14), 2), "14") FROM src LIMIT 1""") + createQueryTest("NaN to Decimal", "SELECT CAST(CAST('NaN' AS DOUBLE) AS DECIMAL(1,1)) FROM src LIMIT 1") -- cgit v1.2.3