about summary refs log tree commit diff
diff options
context:
space:
mode:
authorravipesala <ravindra.pesala@huawei.com>2014-10-03 11:25:18 -0700
committerMichael Armbrust <michael@databricks.com>2014-10-03 11:25:18 -0700
commit22f8e1ee7c4ea7b3bd4c6faaf0fe5b88a134ae12 (patch)
tree67ec0dcd79f419853cc0581e43e80b92d2b8f46e
parent9d320e222c221e5bb827cddf01a83e64a16d74ff (diff)
downloadspark-22f8e1ee7c4ea7b3bd4c6faaf0fe5b88a134ae12.tar.gz
spark-22f8e1ee7c4ea7b3bd4c6faaf0fe5b88a134ae12.tar.bz2
spark-22f8e1ee7c4ea7b3bd4c6faaf0fe5b88a134ae12.zip
[SPARK-2693][SQL] Supported for UDAF Hive Aggregates like PERCENTILE
Implemented UDAF Hive aggregates by adding wrapper to Spark Hive. Author: ravipesala <ravindra.pesala@huawei.com> Closes #2620 from ravipesala/SPARK-2693 and squashes the following commits: a8df326 [ravipesala] Removed resolver from constructor arguments caf25c6 [ravipesala] Fixed style issues 5786200 [ravipesala] Supported for UDAF Hive Aggregates like PERCENTILE
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala46
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala4
2 files changed, 46 insertions, 4 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index 732e4976f6..68f93f247d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -22,7 +22,7 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.hive.common.`type`.HiveDecimal
-import org.apache.hadoop.hive.ql.exec.UDF
+import org.apache.hadoop.hive.ql.exec.{UDF, UDAF}
import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry}
import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
import org.apache.hadoop.hive.ql.udf.generic._
@@ -57,7 +57,8 @@ private[hive] abstract class HiveFunctionRegistry
} else if (
classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveGenericUdaf(functionClassName, children)
-
+ } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
+ HiveUdaf(functionClassName, children)
} else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveGenericUdtf(functionClassName, Nil, children)
} else {
@@ -194,6 +195,37 @@ private[hive] case class HiveGenericUdaf(
def newInstance() = new HiveUdafFunction(functionClassName, children, this)
}
+/** It is used as a wrapper for the hive functions which uses UDAF interface */
+private[hive] case class HiveUdaf(
+ functionClassName: String,
+ children: Seq[Expression]) extends AggregateExpression
+ with HiveInspectors
+ with HiveFunctionFactory {
+
+ type UDFType = UDAF
+
+ @transient
+ protected lazy val resolver: AbstractGenericUDAFResolver = new GenericUDAFBridge(createFunction())
+
+ @transient
+ protected lazy val objectInspector = {
+ resolver.getEvaluator(children.map(_.dataType.toTypeInfo).toArray)
+ .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
+ }
+
+ @transient
+ protected lazy val inspectors = children.map(_.dataType).map(toInspector)
+
+ def dataType: DataType = inspectorToDataType(objectInspector)
+
+ def nullable: Boolean = true
+
+ override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
+
+ def newInstance() =
+ new HiveUdafFunction(functionClassName, children, this, true)
+}
+
/**
* Converts a Hive Generic User Defined Table Generating Function (UDTF) to a
* [[catalyst.expressions.Generator Generator]]. Note that the semantics of Generators do not allow
@@ -275,14 +307,20 @@ private[hive] case class HiveGenericUdtf(
private[hive] case class HiveUdafFunction(
functionClassName: String,
exprs: Seq[Expression],
- base: AggregateExpression)
+ base: AggregateExpression,
+ isUDAFBridgeRequired: Boolean = false)
extends AggregateFunction
with HiveInspectors
with HiveFunctionFactory {
def this() = this(null, null, null)
- private val resolver = createFunction[AbstractGenericUDAFResolver]()
+ private val resolver =
+ if (isUDAFBridgeRequired) {
+ new GenericUDAFBridge(createFunction[UDAF]())
+ } else {
+ createFunction[AbstractGenericUDAFResolver]()
+ }
private val inspectors = exprs.map(_.dataType).map(toInspector).toArray
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
index cc125d539c..e4324e9528 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
@@ -79,6 +79,10 @@ class HiveUdfSuite extends HiveComparisonTest {
sql("SELECT testUdf(pair) FROM hiveUdfTestTable")
sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf")
}
+
+ test("SPARK-2693 udaf aggregates test") {
+ assert(sql("SELECT percentile(key,1) FROM src").first === sql("SELECT max(key) FROM src").first)
+ }
}
class TestPair(x: Int, y: Int) extends Writable with Serializable {