author     Herman van Hovell <hvanhovell@questtec.nl>  2015-11-29 14:13:11 -0800
committer  Yin Huai <yhuai@databricks.com>             2015-11-29 14:13:11 -0800
commit     3d28081e53698ed77e93c04299957c02bcaba9bf (patch)
tree       8c4c791b93a06e975a31b140784c38bc6980b303
parent     cc7a1bc9370b163f51230e5ca4be612d133a5086 (diff)
[SPARK-12024][SQL] More efficient multi-column counting.
In https://github.com/apache/spark/pull/9409 we enabled multi-column counting. The approach taken in that PR introduces a bit of overhead by first creating a row only to check whether all of the columns are non-null. This PR removes that technical debt: Count now takes multiple columns as its input. To make this work, I have also added support for multiple columns in the single-distinct code path.

cc yhuai

Author: Herman van Hovell <hvanhovell@questtec.nl>

Closes #10015 from hvanhovell/SPARK-12024.
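For context, a minimal sketch of the intended COUNT semantics, using plain Scala collections rather than Catalyst expressions (multiColumnCount and the Option-based rows are illustrative only, not part of the patch): a row is counted only when every argument is non-null.

    // Illustrative only: Option[_] stands in for a nullable column value.
    def multiColumnCount(rows: Seq[Seq[Option[Any]]]): Long =
      rows.count(_.forall(_.nonEmpty)).toLong

    val rows = Seq(
      Seq(Option("a"), Option(1)),
      Seq(None, Option(2)),        // not counted: first column is null
      Seq(Option("c"), Option(3)))
    assert(multiColumnCount(rows) == 2L)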
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala              | 21
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala       | 27
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala                      | 14
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala   | 14
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala                             | 39
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala                                |  4
6 files changed, 33 insertions(+), 86 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
index 09a1da9200..441f52ab5c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
@@ -21,8 +21,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
-case class Count(child: Expression) extends DeclarativeAggregate {
- override def children: Seq[Expression] = child :: Nil
+case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
override def nullable: Boolean = false
@@ -30,7 +29,7 @@ case class Count(child: Expression) extends DeclarativeAggregate {
override def dataType: DataType = LongType
// Expected input data type.
- override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+ override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(AnyDataType)
private lazy val count = AttributeReference("count", LongType)()
@@ -41,7 +40,7 @@ case class Count(child: Expression) extends DeclarativeAggregate {
)
override lazy val updateExpressions = Seq(
- /* count = */ If(IsNull(child), count, count + 1L)
+ /* count = */ If(children.map(IsNull).reduce(Or), count, count + 1L)
)
override lazy val mergeExpressions = Seq(
@@ -54,17 +53,5 @@ case class Count(child: Expression) extends DeclarativeAggregate {
}
object Count {
- def apply(children: Seq[Expression]): Count = {
- // This is used to deal with COUNT DISTINCT. When we have multiple
- // children (COUNT(DISTINCT col1, col2, ...)), we wrap them in a STRUCT (i.e. a Row).
- // Also, the semantic of COUNT(DISTINCT col1, col2, ...) is that if there is any
- // null in the arguments, we will not count that row. So, we use DropAnyNull at here
- // to return a null when any field of the created STRUCT is null.
- val child = if (children.size > 1) {
- DropAnyNull(CreateStruct(children))
- } else {
- children.head
- }
- Count(child)
- }
+ def apply(child: Expression): Count = Count(child :: Nil)
}
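The update expression above folds the per-column null checks directly into the aggregation, so no intermediate row is built just to test for nulls. A hedged sketch of the evaluated behaviour, with plain booleans standing in for Catalyst's IsNull/Or (anyNull and update are illustrative helpers, not in the patch):

    // For Count(Seq(a, b)) the update expression is effectively
    //   If(IsNull(a) || IsNull(b), count, count + 1L)
    def anyNull(args: Seq[Option[Any]]): Boolean =
      args.map(_.isEmpty).reduce(_ || _)

    def update(count: Long, args: Seq[Option[Any]]): Long =
      if (anyNull(args)) count else count + 1L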
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 694a2a7c54..40b1eec63e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -426,30 +426,3 @@ case class Greatest(children: Seq[Expression]) extends Expression {
}
}
-/** Operator that drops a row when it contains any nulls. */
-case class DropAnyNull(child: Expression) extends UnaryExpression with ExpectsInputTypes {
- override def nullable: Boolean = true
- override def dataType: DataType = child.dataType
- override def inputTypes: Seq[AbstractDataType] = Seq(StructType)
-
- protected override def nullSafeEval(input: Any): InternalRow = {
- val row = input.asInstanceOf[InternalRow]
- if (row.anyNull) {
- null
- } else {
- row
- }
- }
-
- override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- nullSafeCodeGen(ctx, ev, eval => {
- s"""
- if ($eval.anyNull()) {
- ${ev.isNull} = true;
- } else {
- ${ev.value} = $eval;
- }
- """
- })
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 2901d8f2ef..06d14fcf8b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -362,9 +362,14 @@ object LikeSimplification extends Rule[LogicalPlan] {
* Null value propagation from bottom to top of the expression tree.
*/
object NullPropagation extends Rule[LogicalPlan] {
+ def nonNullLiteral(e: Expression): Boolean = e match {
+ case Literal(null, _) => false
+ case _ => true
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
- case e @ AggregateExpression(Count(Literal(null, _)), _, _) =>
+ case e @ AggregateExpression(Count(exprs), _, _) if !exprs.exists(nonNullLiteral) =>
Cast(Literal(0L), e.dataType)
case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
@@ -377,16 +382,13 @@ object NullPropagation extends Rule[LogicalPlan] {
Literal.create(null, e.dataType)
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
- case e @ AggregateExpression(Count(expr), mode, false) if !expr.nullable =>
+ case e @ AggregateExpression(Count(exprs), mode, false) if !exprs.exists(_.nullable) =>
// This rule should be only triggered when isDistinct field is false.
AggregateExpression(Count(Literal(1)), mode, isDistinct = false)
// For Coalesce, remove null literals.
case e @ Coalesce(children) =>
- val newChildren = children.filter {
- case Literal(null, _) => false
- case _ => true
- }
+ val newChildren = children.filter(nonNullLiteral)
if (newChildren.length == 0) {
Literal.create(null, e.dataType)
} else if (newChildren.length == 1) {
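A minimal sketch of the constant-folding idea behind the first Count rule above, on a toy expression tree (Lit, CountExpr and simplify are hypothetical stand-ins for Catalyst's Literal, Count and the rule itself):

    sealed trait Expr
    case class Lit(value: Any) extends Expr
    case class CountExpr(children: Seq[Expr]) extends Expr

    def nonNullLiteral(e: Expr): Boolean = e match {
      case Lit(null) => false
      case _         => true
    }

    def simplify(e: Expr): Expr = e match {
      // COUNT over arguments that are all null literals always returns 0.
      case CountExpr(exprs) if !exprs.exists(nonNullLiteral) => Lit(0L)
      case other => other
    }

    // simplify(CountExpr(Seq(Lit(null))))          == Lit(0L)
    // simplify(CountExpr(Seq(Lit(null), Lit(1))))  is left unchanged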
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
index c1e3c17b87..0df673bb9f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
@@ -231,18 +231,4 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2)
}
}
-
- test("function dropAnyNull") {
- val drop = DropAnyNull(CreateStruct(Seq('a.string.at(0), 'b.string.at(1))))
- val a = create_row("a", "q")
- val nullStr: String = null
- checkEvaluation(drop, a, a)
- checkEvaluation(drop, null, create_row("b", nullStr))
- checkEvaluation(drop, null, create_row(nullStr, nullStr))
-
- val row = 'r.struct(
- StructField("a", StringType, false),
- StructField("b", StringType, true)).at(0)
- checkEvaluation(DropAnyNull(row), null, create_row(null))
- }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index a70e41436c..76b938cdb6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -146,20 +146,16 @@ object Utils {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
// functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one
- // DISTINCT aggregate function, all of those functions will have the same column expression.
+ // DISTINCT aggregate function, all of those functions will have the same column expressions.
// For example, it would be valid for functionsWithDistinct to be
// [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is
// disallowed because those two distinct aggregates have different column expressions.
- val distinctColumnExpression: Expression = {
- val allDistinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children
- assert(allDistinctColumnExpressions.length == 1)
- allDistinctColumnExpressions.head
- }
- val namedDistinctColumnExpression: NamedExpression = distinctColumnExpression match {
+ val distinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children
+ val namedDistinctColumnExpressions = distinctColumnExpressions.map {
case ne: NamedExpression => ne
case other => Alias(other, other.toString)()
}
- val distinctColumnAttribute: Attribute = namedDistinctColumnExpression.toAttribute
+ val distinctColumnAttributes = namedDistinctColumnExpressions.map(_.toAttribute)
val groupingAttributes = groupingExpressions.map(_.toAttribute)
// 1. Create an Aggregate Operator for partial aggregations.
@@ -170,10 +166,11 @@ object Utils {
// We will group by the original grouping expression, plus an additional expression for the
// DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping
// expressions will be [key, value].
- val partialAggregateGroupingExpressions = groupingExpressions :+ namedDistinctColumnExpression
+ val partialAggregateGroupingExpressions =
+ groupingExpressions ++ namedDistinctColumnExpressions
val partialAggregateResult =
groupingAttributes ++
- Seq(distinctColumnAttribute) ++
+ distinctColumnAttributes ++
partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
if (usesTungstenAggregate) {
TungstenAggregate(
@@ -208,28 +205,28 @@ object Utils {
partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
val partialMergeAggregateResult =
groupingAttributes ++
- Seq(distinctColumnAttribute) ++
+ distinctColumnAttributes ++
partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
if (usesTungstenAggregate) {
TungstenAggregate(
requiredChildDistributionExpressions = Some(groupingAttributes),
- groupingExpressions = groupingAttributes :+ distinctColumnAttribute,
+ groupingExpressions = groupingAttributes ++ distinctColumnAttributes,
nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
completeAggregateExpressions = Nil,
completeAggregateAttributes = Nil,
- initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
+ initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
resultExpressions = partialMergeAggregateResult,
child = partialAggregate)
} else {
SortBasedAggregate(
requiredChildDistributionExpressions = Some(groupingAttributes),
- groupingExpressions = groupingAttributes :+ distinctColumnAttribute,
+ groupingExpressions = groupingAttributes ++ distinctColumnAttributes,
nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
completeAggregateExpressions = Nil,
completeAggregateAttributes = Nil,
- initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
+ initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
resultExpressions = partialMergeAggregateResult,
child = partialAggregate)
}
@@ -244,14 +241,16 @@ object Utils {
expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
}
+ val distinctColumnAttributeLookup =
+ distinctColumnExpressions.zip(distinctColumnAttributes).toMap
val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
// Children of an AggregateFunction with DISTINCT keyword has already
// been evaluated. At here, we need to replace original children
// to AttributeReferences.
case agg @ AggregateExpression(aggregateFunction, mode, true) =>
- val rewrittenAggregateFunction = aggregateFunction.transformDown {
- case expr if expr == distinctColumnExpression => distinctColumnAttribute
- }.asInstanceOf[AggregateFunction]
+ val rewrittenAggregateFunction = aggregateFunction
+ .transformDown(distinctColumnAttributeLookup)
+ .asInstanceOf[AggregateFunction]
// We rewrite the aggregate function to a non-distinct aggregation because
// its input will have distinct arguments.
// We just keep the isDistinct setting to true, so when users look at the query plan,
@@ -270,7 +269,7 @@ object Utils {
nonCompleteAggregateAttributes = finalAggregateAttributes,
completeAggregateExpressions = completeAggregateExpressions,
completeAggregateAttributes = completeAggregateAttributes,
- initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
+ initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
resultExpressions = resultExpressions,
child = partialMergeAggregate)
} else {
@@ -281,7 +280,7 @@ object Utils {
nonCompleteAggregateAttributes = finalAggregateAttributes,
completeAggregateExpressions = completeAggregateExpressions,
completeAggregateAttributes = completeAggregateAttributes,
- initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
+ initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
resultExpressions = resultExpressions,
child = partialMergeAggregate)
}
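A rough emulation of the two-phase plan this code builds, for a hypothetical query SELECT key, COUNT(DISTINCT a, b) FROM t GROUP BY key, using plain Scala collections (Row and countDistinctPairs are illustrative only): the partial aggregate groups by the grouping keys plus the DISTINCT columns, and the final aggregate then counts the surviving rows per key.

    case class Row(key: String, a: Option[Int], b: Option[Int])

    def countDistinctPairs(rows: Seq[Row]): Map[String, Int] =
      rows
        .filter(r => r.a.nonEmpty && r.b.nonEmpty) // null arguments are not counted
        .map(r => (r.key, r.a, r.b))
        .distinct                                  // partial phase: group by (key, a, b)
        .groupBy(_._1)                             // final phase: group by key
        .map { case (k, vs) => k -> vs.size }      // count the distinct (a, b) pairs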
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
index fc873c04f8..893e800a61 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
@@ -152,8 +152,8 @@ class WindowSpec private[sql](
case Sum(child) => WindowExpression(
UnresolvedWindowFunction("sum", child :: Nil),
WindowSpecDefinition(partitionSpec, orderSpec, frame))
- case Count(child) => WindowExpression(
- UnresolvedWindowFunction("count", child :: Nil),
+ case Count(children) => WindowExpression(
+ UnresolvedWindowFunction("count", children),
WindowSpecDefinition(partitionSpec, orderSpec, frame))
case First(child, ignoreNulls) => WindowExpression(
// TODO this is a hack for Hive UDAF first_value