diff options
Diffstat (limited to 'sql/hive')
-rw-r--r-- | sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala | 46 | ||||
-rw-r--r-- | sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala | 4 |
2 files changed, 46 insertions, 4 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 732e4976f6..68f93f247d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -22,7 +22,7 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.hive.common.`type`.HiveDecimal -import org.apache.hadoop.hive.ql.exec.UDF +import org.apache.hadoop.hive.ql.exec.{UDF, UDAF} import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic._ @@ -57,7 +57,8 @@ private[hive] abstract class HiveFunctionRegistry } else if ( classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { HiveGenericUdaf(functionClassName, children) - + } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveUdaf(functionClassName, children) } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { HiveGenericUdtf(functionClassName, Nil, children) } else { @@ -194,6 +195,37 @@ private[hive] case class HiveGenericUdaf( def newInstance() = new HiveUdafFunction(functionClassName, children, this) } +/** It is used as a wrapper for the hive functions which uses UDAF interface */ +private[hive] case class HiveUdaf( + functionClassName: String, + children: Seq[Expression]) extends AggregateExpression + with HiveInspectors + with HiveFunctionFactory { + + type UDFType = UDAF + + @transient + protected lazy val resolver: AbstractGenericUDAFResolver = new GenericUDAFBridge(createFunction()) + + @transient + protected lazy val objectInspector = { + resolver.getEvaluator(children.map(_.dataType.toTypeInfo).toArray) + .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) + } + + @transient + protected lazy val inspectors = children.map(_.dataType).map(toInspector) + + def dataType: DataType = inspectorToDataType(objectInspector) + + def nullable: Boolean = true + + override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})" + + def newInstance() = + new HiveUdafFunction(functionClassName, children, this, true) +} + /** * Converts a Hive Generic User Defined Table Generating Function (UDTF) to a * [[catalyst.expressions.Generator Generator]]. Note that the semantics of Generators do not allow @@ -275,14 +307,20 @@ private[hive] case class HiveGenericUdtf( private[hive] case class HiveUdafFunction( functionClassName: String, exprs: Seq[Expression], - base: AggregateExpression) + base: AggregateExpression, + isUDAFBridgeRequired: Boolean = false) extends AggregateFunction with HiveInspectors with HiveFunctionFactory { def this() = this(null, null, null) - private val resolver = createFunction[AbstractGenericUDAFResolver]() + private val resolver = + if (isUDAFBridgeRequired) { + new GenericUDAFBridge(createFunction[UDAF]()) + } else { + createFunction[AbstractGenericUDAFResolver]() + } private val inspectors = exprs.map(_.dataType).map(toInspector).toArray diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index cc125d539c..e4324e9528 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -79,6 +79,10 @@ class HiveUdfSuite extends HiveComparisonTest { sql("SELECT testUdf(pair) FROM hiveUdfTestTable") sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") } + + test("SPARK-2693 udaf aggregates test") { + assert(sql("SELECT percentile(key,1) FROM src").first === sql("SELECT max(key) FROM src").first) + } } class TestPair(x: Int, y: Int) extends Writable with Serializable { |