diff options
Diffstat (limited to 'sql')
-rw-r--r-- | sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala | 68 | ||||
-rw-r--r-- | sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala | 59 |
2 files changed, 104 insertions, 23 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index fa9012b96e..a85d4db88d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -60,20 +60,36 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) val functionClassName = functionInfo.getFunctionClass.getName - if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveSimpleUDF(new HiveFunctionWrapper(functionClassName), children) - } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUDF(new HiveFunctionWrapper(functionClassName), children) - } else if ( - classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveUDAFFunction(new HiveFunctionWrapper(functionClassName), children) - } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveUDAFFunction( - new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true) - } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children) - } else { - sys.error(s"No handler for udf ${functionInfo.getFunctionClass}") + // When we instantiate hive UDF wrapper class, we may throw exception if the input expressions + // don't satisfy the hive UDF, such as type mismatch, input number mismatch, etc. Here we + // catch the exception and throw AnalysisException instead. + try { + if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveSimpleUDF(new HiveFunctionWrapper(functionClassName), children) + } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveGenericUDF(new HiveFunctionWrapper(functionClassName), children) + } else if ( + classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveUDAFFunction(new HiveFunctionWrapper(functionClassName), children) + } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveUDAFFunction( + new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true) + } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { + val udtf = HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children) + udtf.elementTypes // Force it to check input data types. + udtf + } else { + throw new AnalysisException(s"No handler for udf ${functionInfo.getFunctionClass}") + } + } catch { + case analysisException: AnalysisException => + // If the exception is an AnalysisException, just throw it. + throw analysisException + case throwable: Throwable => + // If there is any other error, we throw an AnalysisException. + val errorMessage = s"No handler for Hive udf ${functionInfo.getFunctionClass} " + + s"because: ${throwable.getMessage}." + throw new AnalysisException(errorMessage) } } } @@ -134,7 +150,7 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre @transient private lazy val conversionHelper = new ConversionHelper(method, arguments) - val dataType = javaClassToDataType(method.getReturnType) + override val dataType = javaClassToDataType(method.getReturnType) @transient lazy val returnInspector = ObjectInspectorFactory.getReflectionObjectInspector( @@ -205,7 +221,7 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr new DeferredObjectAdapter(inspect, child.dataType) }.toArray[DeferredObject] - lazy val dataType: DataType = inspectorToDataType(returnInspector) + override val dataType: DataType = inspectorToDataType(returnInspector) override def eval(input: InternalRow): Any = { returnInspector // Make sure initialized. @@ -231,6 +247,12 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr * Resolves [[UnresolvedWindowFunction]] to [[HiveWindowFunction]]. */ private[spark] object ResolveHiveWindowFunction extends Rule[LogicalPlan] { + private def shouldResolveFunction( + unresolvedWindowFunction: UnresolvedWindowFunction, + windowSpec: WindowSpecDefinition): Boolean = { + unresolvedWindowFunction.childrenResolved && windowSpec.childrenResolved + } + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case p: LogicalPlan if !p.childrenResolved => p @@ -238,9 +260,11 @@ private[spark] object ResolveHiveWindowFunction extends Rule[LogicalPlan] { // replaced those WindowSpecReferences. case p: LogicalPlan => p transformExpressions { + // We will not start to resolve the function unless all arguments are resolved + // and all expressions in window spec are fixed. case WindowExpression( - UnresolvedWindowFunction(name, children), - windowSpec: WindowSpecDefinition) => + u @ UnresolvedWindowFunction(name, children), + windowSpec: WindowSpecDefinition) if shouldResolveFunction(u, windowSpec) => // First, let's find the window function info. val windowFunctionInfo: WindowFunctionInfo = Option(FunctionRegistry.getWindowFunctionInfo(name.toLowerCase)).getOrElse( @@ -256,7 +280,7 @@ private[spark] object ResolveHiveWindowFunction extends Rule[LogicalPlan] { // are expressions in Order By clause. if (classOf[GenericUDAFRank].isAssignableFrom(functionClass)) { if (children.nonEmpty) { - throw new AnalysisException(s"$name does not take input parameters.") + throw new AnalysisException(s"$name does not take input parameters.") } windowSpec.orderSpec.map(_.child) } else { @@ -358,7 +382,7 @@ private[hive] case class HiveWindowFunction( evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors) } - override def dataType: DataType = + override val dataType: DataType = if (!pivotResult) { inspectorToDataType(returnInspector) } else { @@ -478,7 +502,7 @@ private[hive] case class HiveGenericUDTF( @transient protected lazy val collector = new UDTFCollector - lazy val elementTypes = outputInspector.getAllStructFieldRefs.asScala.map { + override lazy val elementTypes = outputInspector.getAllStructFieldRefs.asScala.map { field => (inspectorToDataType(field.getFieldObjectInspector), true) } @@ -602,6 +626,6 @@ private[hive] case class HiveUDAFFunction( override def supportsPartial: Boolean = false - override lazy val dataType: DataType = inspectorToDataType(returnInspector) + override val dataType: DataType = inspectorToDataType(returnInspector) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 3c8a0091c8..5f9a447759 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -21,7 +21,8 @@ import java.io.{DataInput, DataOutput} import java.util.{ArrayList, Arrays, Properties} import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFAverage, GenericUDF} +import org.apache.hadoop.hive.ql.udf.UDAFPercentile +import org.apache.hadoop.hive.ql.udf.generic.{GenericUDFOPAnd, GenericUDTFExplode, GenericUDAFAverage, GenericUDF} import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} @@ -299,6 +300,62 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton { hiveContext.reset() } + + test("Hive UDFs with insufficient number of input arguments should trigger an analysis error") { + Seq((1, 2)).toDF("a", "b").registerTempTable("testUDF") + + { + // HiveSimpleUDF + sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'") + val message = intercept[AnalysisException] { + sql("SELECT testUDFTwoListList() FROM testUDF") + }.getMessage + assert(message.contains("No handler for Hive udf")) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") + } + + { + // HiveGenericUDF + sql(s"CREATE TEMPORARY FUNCTION testUDFAnd AS '${classOf[GenericUDFOPAnd].getName}'") + val message = intercept[AnalysisException] { + sql("SELECT testUDFAnd() FROM testUDF") + }.getMessage + assert(message.contains("No handler for Hive udf")) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFAnd") + } + + { + // Hive UDAF + sql(s"CREATE TEMPORARY FUNCTION testUDAFPercentile AS '${classOf[UDAFPercentile].getName}'") + val message = intercept[AnalysisException] { + sql("SELECT testUDAFPercentile(a) FROM testUDF GROUP BY b") + }.getMessage + assert(message.contains("No handler for Hive udf")) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDAFPercentile") + } + + { + // AbstractGenericUDAFResolver + sql(s"CREATE TEMPORARY FUNCTION testUDAFAverage AS '${classOf[GenericUDAFAverage].getName}'") + val message = intercept[AnalysisException] { + sql("SELECT testUDAFAverage() FROM testUDF GROUP BY b") + }.getMessage + assert(message.contains("No handler for Hive udf")) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDAFAverage") + } + + { + // Hive UDTF + sql(s"CREATE TEMPORARY FUNCTION testUDTFExplode AS '${classOf[GenericUDTFExplode].getName}'") + val message = intercept[AnalysisException] { + sql("SELECT testUDTFExplode() FROM testUDF") + }.getMessage + assert(message.contains("No handler for Hive udf")) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDTFExplode") + } + + sqlContext.dropTempTable("testUDF") + } } class TestPair(x: Int, y: Int) extends Writable with Serializable { |