author    Wenchen Fan <cloud0fan@163.com>    2015-09-24 09:54:07 -0700
committer Yin Huai <yhuai@databricks.com>    2015-09-24 09:54:07 -0700
commit    341b13f8f5eb118f1fb4d4f84418715ac4750a4d (patch)
tree      e6b8d90482e7d8e6cc339a6f357d320e92d7e8ff
parent    02144d6745ec0a6d8877d969feb82139bd22437f (diff)
[SPARK-10765] [SQL] use new aggregate interface for hive UDAF
Author: Wenchen Fan <cloud0fan@163.com>

Closes #8874 from cloud-fan/hive-agg.
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala |   7
 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala                     |  14
 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala       |   2
 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala                     |  51
 sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala                                 | 139
 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala                   |   2
 6 files changed, 129 insertions(+), 86 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 576d8c7a3a..d8699533cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.InternalRow
@@ -169,6 +168,12 @@ abstract class AggregateFunction2
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
+
+ /**
+ * Indicates whether this function supports partial aggregation.
+ * Currently, Hive UDAF is the only aggregate function that does not.
+ */
+ def supportsPartial: Boolean = true
}
/**
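
The new flag gives the planner a single question to ask before committing to a two-phase plan. A minimal, self-contained sketch of that decision, using illustrative names rather than Spark internals:

    // Hedged sketch: `Agg` stands in for AggregateFunction2; nothing here is
    // Spark's actual code.
    trait Agg { def supportsPartial: Boolean = true }

    // A function whose state lives in an opaque external buffer (as a Hive
    // UDAF's AggregationBuffer does) cannot be merged across partitions, so
    // it opts out of partial aggregation.
    object OpaqueBufferAgg extends Agg {
      override def supportsPartial: Boolean = false
    }

    def planShape(functions: Seq[Agg]): String =
      if (functions.forall(_.supportsPartial)) "partial aggregation + final merge"
      else "single Complete-mode aggregation"
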
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 41b215c792..b078c8b6b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -221,7 +221,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
val aggregateOperator =
- if (functionsWithDistinct.isEmpty) {
+ if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
+ if (functionsWithDistinct.nonEmpty) {
+ sys.error("Distinct columns cannot exist in an Aggregate operator containing " +
+ "aggregate functions that don't support partial aggregation.")
+ } else {
+ aggregate.Utils.planAggregateWithoutPartial(
+ groupingExpressions,
+ aggregateExpressions,
+ aggregateFunctionMap,
+ resultExpressions,
+ planLater(child))
+ }
+ } else if (functionsWithDistinct.isEmpty) {
aggregate.Utils.planAggregateWithoutDistinct(
groupingExpressions,
aggregateExpressions,
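
The branch order above matters: the non-partial check runs before the distinct check, and mixing the two fails outright because distinct aggregation is itself built on partial aggregation. The same control flow, reduced to a sketch with assumed boolean inputs (the distinct plan's method name is not part of this diff, so it stays abstract here):

    // Hedged sketch of the strategy's branching; `hasNonPartial` and
    // `hasDistinct` are assumed to be precomputed from aggregateExpressions.
    def chooseAggregatePlan(hasNonPartial: Boolean, hasDistinct: Boolean): String =
      if (hasNonPartial && hasDistinct)
        sys.error("Distinct columns cannot be combined with non-partial aggregates.")
      else if (hasNonPartial) "Utils.planAggregateWithoutPartial"
      else if (!hasDistinct) "Utils.planAggregateWithoutDistinct"
      else "the distinct-aware plan (unchanged by this patch)"
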
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index abca373b0c..62dbc07e88 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -26,7 +26,7 @@ import org.apache.spark.unsafe.KVIterator
import scala.collection.mutable.ArrayBuffer
/**
- * The base class of [[SortBasedAggregationIterator]] and [[UnsafeHybridAggregationIterator]].
+ * The base class of [[SortBasedAggregationIterator]].
* It mainly contains two parts:
* 1. It initializes aggregate functions.
* 2. It creates two functions, `processRow` and `generateOutput` based on [[AggregateMode]] of
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 80816a095e..4f5e86cceb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -37,6 +37,57 @@ object Utils {
UnsafeProjection.canSupport(groupingExpressions)
}
+ def planAggregateWithoutPartial(
+ groupingExpressions: Seq[Expression],
+ aggregateExpressions: Seq[AggregateExpression2],
+ aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)],
+ resultExpressions: Seq[NamedExpression],
+ child: SparkPlan): Seq[SparkPlan] = {
+
+ val namedGroupingExpressions = groupingExpressions.map {
+ case ne: NamedExpression => ne -> ne
+ // If the expression is not a NamedExpression, we add an alias.
+ // So, when we generate the result of the operator, the Aggregate operator
+ // can directly get the Seq of attributes representing the grouping expressions.
+ case other =>
+ val withAlias = Alias(other, other.toString)()
+ other -> withAlias
+ }
+ val groupExpressionMap = namedGroupingExpressions.toMap
+ val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
+
+ val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
+ val completeAggregateAttributes =
+ completeAggregateExpressions.map {
+ expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2
+ }
+
+ val rewrittenResultExpressions = resultExpressions.map { expr =>
+ expr.transformDown {
+ case agg: AggregateExpression2 =>
+ aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2
+ case expression =>
+ // We do not rely on the equality check here since attributes may
+ // differ cosmetically. Instead, we use semanticEquals.
+ groupExpressionMap.collectFirst {
+ case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+ }.getOrElse(expression)
+ }.asInstanceOf[NamedExpression]
+ }
+
+ SortBasedAggregate(
+ requiredChildDistributionExpressions = Some(namedGroupingAttributes),
+ groupingExpressions = namedGroupingExpressions.map(_._2),
+ nonCompleteAggregateExpressions = Nil,
+ nonCompleteAggregateAttributes = Nil,
+ completeAggregateExpressions = completeAggregateExpressions,
+ completeAggregateAttributes = completeAggregateAttributes,
+ initialInputBufferOffset = 0,
+ resultExpressions = rewrittenResultExpressions,
+ child = child
+ ) :: Nil
+ }
+
def planAggregateWithoutDistinct(
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[AggregateExpression2],
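
Two details of planAggregateWithoutPartial are easy to miss: every grouping expression is given a name up front so the operator can expose grouping columns as plain attributes, and the result expressions are rewritten with semanticEquals because textually different attributes can be semantically identical. The naming step in isolation, as a runnable sketch with toy types (Named and Raw are not Catalyst classes):

    // Hedged sketch of the aliasing pattern using stand-in types.
    sealed trait Expr
    case class Named(name: String, child: Expr) extends Expr // ~ NamedExpression
    case class Raw(sql: String) extends Expr                 // ~ any other Expression

    def nameGroupingExpressions(grouping: Seq[Expr]): Seq[(Expr, Named)] =
      grouping.map {
        case n: Named => n -> n                                // already named: reuse it
        case other    => other -> Named(other.toString, other) // wrap in a fresh alias
      }
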
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index cad02373e5..fa9012b96e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
@@ -65,9 +66,10 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry)
HiveGenericUDF(new HiveFunctionWrapper(functionClassName), children)
} else if (
classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveGenericUDAF(new HiveFunctionWrapper(functionClassName), children)
+ HiveUDAFFunction(new HiveFunctionWrapper(functionClassName), children)
} else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveUDAF(new HiveFunctionWrapper(functionClassName), children)
+ HiveUDAFFunction(
+ new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true)
} else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children)
} else {
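
Both Hive aggregate interfaces now funnel into the single HiveUDAFFunction; the legacy UDAF interface differs only in being adapted through Hive's GenericUDAFBridge, which the isUDAFBridgeRequired flag requests. The class checks, mirrored in a standalone sketch (the Hive classes are real, the function is illustrative):

    import org.apache.hadoop.hive.ql.exec.UDAF
    import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver

    // Hedged sketch: report whether a given Hive function class needs the
    // UDAF-to-resolver bridge. Not the registry's actual code.
    def needsUdafBridge(clazz: Class[_]): Boolean =
      if (classOf[AbstractGenericUDAFResolver].isAssignableFrom(clazz)) false // resolver API
      else if (classOf[UDAF].isAssignableFrom(clazz)) true                    // legacy API
      else sys.error(s"${clazz.getName} is not a Hive UDAF class")
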
@@ -441,70 +443,6 @@ private[hive] case class HiveWindowFunction(
new HiveWindowFunction(funcWrapper, pivotResult, isUDAFBridgeRequired, children)
}
-private[hive] case class HiveGenericUDAF(
- funcWrapper: HiveFunctionWrapper,
- children: Seq[Expression]) extends AggregateExpression1
- with HiveInspectors {
-
- type UDFType = AbstractGenericUDAFResolver
-
- @transient
- protected lazy val resolver: AbstractGenericUDAFResolver = funcWrapper.createFunction()
-
- @transient
- protected lazy val objectInspector = {
- val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false)
- resolver.getEvaluator(parameterInfo)
- .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
- }
-
- @transient
- protected lazy val inspectors = children.map(toInspector)
-
- def dataType: DataType = inspectorToDataType(objectInspector)
-
- def nullable: Boolean = true
-
- override def toString: String = {
- s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
- }
-
- def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this)
-}
-
-/** It is used as a wrapper for the hive functions which uses UDAF interface */
-private[hive] case class HiveUDAF(
- funcWrapper: HiveFunctionWrapper,
- children: Seq[Expression]) extends AggregateExpression1
- with HiveInspectors {
-
- type UDFType = UDAF
-
- @transient
- protected lazy val resolver: AbstractGenericUDAFResolver =
- new GenericUDAFBridge(funcWrapper.createFunction())
-
- @transient
- protected lazy val objectInspector = {
- val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false)
- resolver.getEvaluator(parameterInfo)
- .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
- }
-
- @transient
- protected lazy val inspectors = children.map(toInspector)
-
- def dataType: DataType = inspectorToDataType(objectInspector)
-
- def nullable: Boolean = true
-
- override def toString: String = {
- s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
- }
-
- def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this, true)
-}
-
/**
* Converts a Hive Generic User Defined Table Generating Function (UDTF) to a
* [[Generator]]. Note that the semantics of Generators do not allow
@@ -584,49 +522,86 @@ private[hive] case class HiveGenericUDTF(
}
}
+/**
+ * Currently we don't support partial aggregation for queries using Hive UDAFs,
+ * which may hurt performance significantly.
+ */
private[hive] case class HiveUDAFFunction(
funcWrapper: HiveFunctionWrapper,
- exprs: Seq[Expression],
- base: AggregateExpression1,
+ children: Seq[Expression],
isUDAFBridgeRequired: Boolean = false)
- extends AggregateFunction1
- with HiveInspectors {
+ extends AggregateFunction2 with HiveInspectors {
- def this() = this(null, null, null)
+ def this() = this(null, null)
- private val resolver =
+ @transient
+ private lazy val resolver =
if (isUDAFBridgeRequired) {
new GenericUDAFBridge(funcWrapper.createFunction[UDAF]())
} else {
funcWrapper.createFunction[AbstractGenericUDAFResolver]()
}
- private val inspectors = exprs.map(toInspector).toArray
+ @transient
+ private lazy val inspectors = children.map(toInspector).toArray
- private val function = {
+ @transient
+ private lazy val functionAndInspector = {
val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false)
- resolver.getEvaluator(parameterInfo)
+ val f = resolver.getEvaluator(parameterInfo)
+ f -> f.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors)
}
- private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors)
+ @transient
+ private lazy val function = functionAndInspector._1
+
+ @transient
+ private lazy val returnInspector = functionAndInspector._2
- private val buffer =
- function.getNewAggregationBuffer
+ @transient
+ private[this] var buffer: GenericUDAFEvaluator.AggregationBuffer = _
override def eval(input: InternalRow): Any = unwrap(function.evaluate(buffer), returnInspector)
@transient
- val inputProjection = new InterpretedProjection(exprs)
+ private lazy val inputProjection = new InterpretedProjection(children)
@transient
- protected lazy val cached = new Array[AnyRef](exprs.length)
+ private lazy val cached = new Array[AnyRef](children.length)
@transient
- private lazy val inputDataTypes: Array[DataType] = exprs.map(_.dataType).toArray
+ private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray
+
+ // Hive UDAF has its own buffer, so we don't need to occupy a slot in the aggregation
+ // buffer for it.
+ override def bufferSchema: StructType = StructType(Nil)
- def update(input: InternalRow): Unit = {
+ override def update(_buffer: MutableRow, input: InternalRow): Unit = {
val inputs = inputProjection(input)
function.iterate(buffer, wrap(inputs, inspectors, cached, inputDataTypes))
}
+
+ override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+ throw new UnsupportedOperationException(
+ "Hive UDAF doesn't support partial aggregate")
+ }
+
+ override def cloneBufferAttributes: Seq[Attribute] = Nil
+
+ override def initialize(_buffer: MutableRow): Unit = {
+ buffer = function.getNewAggregationBuffer
+ }
+
+ override def bufferAttributes: Seq[AttributeReference] = Nil
+
+ // We rely on Hive to check the input data types, so use `AnyDataType` here to bypass our
+ // catalyst type checking framework.
+ override def inputTypes: Seq[AbstractDataType] = children.map(_ => AnyDataType)
+
+ override def nullable: Boolean = true
+
+ override def supportsPartial: Boolean = false
+
+ override lazy val dataType: DataType = inspectorToDataType(returnInspector)
}
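
Nothing changes on the user side: a Hive UDAF is still registered by class name and called from SQL; only the physical plan differs, running one Complete-mode pass instead of partial plus final phases. A usage sketch mirroring the suite below, assuming a 1.5-era hiveContext and the src test table:

    // Hedged usage sketch; `hiveContext` and the `src` table are assumed (they
    // exist in Spark's Hive test harness). GenericUDAFAverage is a
    // resolver-based UDAF that ships with Hive.
    import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage

    hiveContext.sql(
      s"CREATE TEMPORARY FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'")
    hiveContext.sql("SELECT test_avg(1), test_avg(substr(value, 5)) FROM src").show()
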
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index d9ba895e1e..3c8a0091c8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -131,7 +131,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton {
hiveContext.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault)
}
- test("SPARK-6409 UDAFAverage test") {
+ test("SPARK-6409 UDAF Average test") {
sql(s"CREATE TEMPORARY FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'")
checkAnswer(
sql("SELECT test_avg(1), test_avg(substr(value,5)) FROM src"),