aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorJosh Rosen <joshrosen@databricks.com>2015-10-14 17:27:27 -0700
committerYin Huai <yhuai@databricks.com>2015-10-14 17:27:50 -0700
commit4ace4f8a9c91beb21a0077e12b75637a4560a542 (patch)
tree3d9c2224cfec1cc0d839114e72b52856d68b8356 /sql/catalyst
parent1baaf2b9bd7c949a8f95cd14fc1be2a56e1139b3 (diff)
downloadspark-4ace4f8a9c91beb21a0077e12b75637a4560a542.tar.gz
spark-4ace4f8a9c91beb21a0077e12b75637a4560a542.tar.bz2
spark-4ace4f8a9c91beb21a0077e12b75637a4560a542.zip
[SPARK-11017] [SQL] Support ImperativeAggregates in TungstenAggregate
This patch extends TungstenAggregate to support ImperativeAggregate functions. The existing TungstenAggregate operator only supported DeclarativeAggregate functions, which are defined in terms of Catalyst expressions and can be evaluated via generated projections. ImperativeAggregate functions, on the other hand, are evaluated by calling their `initialize`, `update`, `merge`, and `eval` methods. The basic strategy here is similar to how SortBasedAggregate evaluates both types of aggregate functions: use a generated projection to evaluate the expression-based declarative aggregates with dummy placeholder expressions inserted in place of the imperative aggregate function output, then invoke the imperative aggregate functions and target them against the aggregation buffer. The bulk of the diff here consists of code that was copied and adapted from SortBasedAggregate, with some key changes to handle TungstenAggregate's sort fallback path. Author: Josh Rosen <joshrosen@databricks.com> Closes #9038 from JoshRosen/support-interpreted-in-tungsten-agg-final.
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala19
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala31
2 files changed, 37 insertions, 13 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
index 8aad0b7dee..c0bc7ec09c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -472,10 +472,20 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
* @param relativeSD the maximum estimation error allowed.
*/
// scalastyle:on
-case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05)
- extends ImperativeAggregate {
+case class HyperLogLogPlusPlus(
+ child: Expression,
+ relativeSD: Double = 0.05,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0)
+ extends ImperativeAggregate {
import HyperLogLogPlusPlus._
+ override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+ override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
+
/**
* HLL++ uses 'p' bits for addressing. The more addressing bits we use, the more precise the
* algorithm will be, and the more memory it will require. The 'p' value is based on the relative
@@ -546,6 +556,11 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05)
AttributeReference(s"MS[$i]", LongType)()
}
+ // Note: although this simply copies aggBufferAttributes, this common code can not be placed
+ // in the superclass because that will lead to initialization ordering issues.
+ override val inputAggBufferAttributes: Seq[AttributeReference] =
+ aggBufferAttributes.map(_.newInstance())
+
/** Fill all words with zeros. */
override def initialize(buffer: MutableRow): Unit = {
var word = 0
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 9ba3a9c980..a2fab258fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -150,6 +150,10 @@ sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInp
* We need to perform similar field number arithmetic when merging multiple intermediate
* aggregate buffers together in `merge()` (in this case, use `inputAggBufferOffset` when accessing
* the input buffer).
+ *
+ * Correct ImperativeAggregate evaluation depends on the correctness of `mutableAggBufferOffset` and
+ * `inputAggBufferOffset`, but not on the correctness of the attribute ids in `aggBufferAttributes`
+ * and `inputAggBufferAttributes`.
*/
abstract class ImperativeAggregate extends AggregateFunction2 {
@@ -172,11 +176,13 @@ abstract class ImperativeAggregate extends AggregateFunction2 {
* avg(y) mutableAggBufferOffset = 2
*
*/
- protected var mutableAggBufferOffset: Int = 0
+ protected val mutableAggBufferOffset: Int
- def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): Unit = {
- mutableAggBufferOffset = newMutableAggBufferOffset
- }
+ /**
+ * Returns a copy of this ImperativeAggregate with an updated mutableAggBufferOffset.
+ * This new copy's attributes may have different ids than the original.
+ */
+ def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate
/**
* The offset of this function's start buffer value in the underlying shared input aggregation
@@ -203,11 +209,17 @@ abstract class ImperativeAggregate extends AggregateFunction2 {
* avg(y) inputAggBufferOffset = 3
*
*/
- protected var inputAggBufferOffset: Int = 0
+ protected val inputAggBufferOffset: Int
- def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): Unit = {
- inputAggBufferOffset = newInputAggBufferOffset
- }
+ /**
+ * Returns a copy of this ImperativeAggregate with an updated mutableAggBufferOffset.
+ * This new copy's attributes may have different ids than the original.
+ */
+ def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate
+
+ // Note: although all subclasses implement inputAggBufferAttributes by simply cloning
+ // aggBufferAttributes, that common clone code cannot be placed here in the abstract
+ // ImperativeAggregate class, since that will lead to initialization ordering issues.
/**
* Initializes the mutable aggregation buffer located in `mutableAggBuffer`.
@@ -231,9 +243,6 @@ abstract class ImperativeAggregate extends AggregateFunction2 {
* Use `fieldNumber + inputAggBufferOffset` to access fields of `inputAggBuffer`.
*/
def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit
-
- final lazy val inputAggBufferAttributes: Seq[AttributeReference] =
- aggBufferAttributes.map(_.newInstance())
}
/**