diff options
author | Herman van Hovell <hvanhovell@questtec.nl> | 2015-11-29 14:13:11 -0800 |
---|---|---|
committer | Yin Huai <yhuai@databricks.com> | 2015-11-29 14:13:11 -0800 |
commit | 3d28081e53698ed77e93c04299957c02bcaba9bf (patch) | |
tree | 8c4c791b93a06e975a31b140784c38bc6980b303 /sql/catalyst/src | |
parent | cc7a1bc9370b163f51230e5ca4be612d133a5086 (diff) | |
download | spark-3d28081e53698ed77e93c04299957c02bcaba9bf.tar.gz spark-3d28081e53698ed77e93c04299957c02bcaba9bf.tar.bz2 spark-3d28081e53698ed77e93c04299957c02bcaba9bf.zip |
[SPARK-12024][SQL] More efficient multi-column counting.
In https://github.com/apache/spark/pull/9409 we enabled multi-column counting. The approach taken in that PR introduces a bit of overhead by first creating a row only to check if all of the columns are non-null.
This PR fixes that technical debt. Count now takes multiple columns as its input. In order to make this work I have also added support for multiple columns in the single distinct code path.
cc yhuai
Author: Herman van Hovell <hvanhovell@questtec.nl>
Closes #10015 from hvanhovell/SPARK-12024.
Diffstat (limited to 'sql/catalyst/src')
4 files changed, 12 insertions, 64 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 09a1da9200..441f52ab5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -21,8 +21,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -case class Count(child: Expression) extends DeclarativeAggregate { - override def children: Seq[Expression] = child :: Nil +case class Count(children: Seq[Expression]) extends DeclarativeAggregate { override def nullable: Boolean = false @@ -30,7 +29,7 @@ case class Count(child: Expression) extends DeclarativeAggregate { override def dataType: DataType = LongType // Expected input data type. - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(AnyDataType) private lazy val count = AttributeReference("count", LongType)() @@ -41,7 +40,7 @@ case class Count(child: Expression) extends DeclarativeAggregate { ) override lazy val updateExpressions = Seq( - /* count = */ If(IsNull(child), count, count + 1L) + /* count = */ If(children.map(IsNull).reduce(Or), count, count + 1L) ) override lazy val mergeExpressions = Seq( @@ -54,17 +53,5 @@ case class Count(child: Expression) extends DeclarativeAggregate { } object Count { - def apply(children: Seq[Expression]): Count = { - // This is used to deal with COUNT DISTINCT. When we have multiple - // children (COUNT(DISTINCT col1, col2, ...)), we wrap them in a STRUCT (i.e. a Row). - // Also, the semantic of COUNT(DISTINCT col1, col2, ...) is that if there is any - // null in the arguments, we will not count that row. So, we use DropAnyNull at here - // to return a null when any field of the created STRUCT is null. - val child = if (children.size > 1) { - DropAnyNull(CreateStruct(children)) - } else { - children.head - } - Count(child) - } + def apply(child: Expression): Count = Count(child :: Nil) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 694a2a7c54..40b1eec63e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -426,30 +426,3 @@ case class Greatest(children: Seq[Expression]) extends Expression { } } -/** Operator that drops a row when it contains any nulls. */ -case class DropAnyNull(child: Expression) extends UnaryExpression with ExpectsInputTypes { - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def inputTypes: Seq[AbstractDataType] = Seq(StructType) - - protected override def nullSafeEval(input: Any): InternalRow = { - val row = input.asInstanceOf[InternalRow] - if (row.anyNull) { - null - } else { - row - } - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, eval => { - s""" - if ($eval.anyNull()) { - ${ev.isNull} = true; - } else { - ${ev.value} = $eval; - } - """ - }) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2901d8f2ef..06d14fcf8b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -362,9 +362,14 @@ object LikeSimplification extends Rule[LogicalPlan] { * Null value propagation from bottom to top of the expression tree. */ object NullPropagation extends Rule[LogicalPlan] { + def nonNullLiteral(e: Expression): Boolean = e match { + case Literal(null, _) => false + case _ => true + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { - case e @ AggregateExpression(Count(Literal(null, _)), _, _) => + case e @ AggregateExpression(Count(exprs), _, _) if !exprs.exists(nonNullLiteral) => Cast(Literal(0L), e.dataType) case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType) case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType) @@ -377,16 +382,13 @@ object NullPropagation extends Rule[LogicalPlan] { Literal.create(null, e.dataType) case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) - case e @ AggregateExpression(Count(expr), mode, false) if !expr.nullable => + case e @ AggregateExpression(Count(exprs), mode, false) if !exprs.exists(_.nullable) => // This rule should be only triggered when isDistinct field is false. AggregateExpression(Count(Literal(1)), mode, isDistinct = false) // For Coalesce, remove null literals. case e @ Coalesce(children) => - val newChildren = children.filter { - case Literal(null, _) => false - case _ => true - } + val newChildren = children.filter(nonNullLiteral) if (newChildren.length == 0) { Literal.create(null, e.dataType) } else if (newChildren.length == 1) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index c1e3c17b87..0df673bb9f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -231,18 +231,4 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2) } } - - test("function dropAnyNull") { - val drop = DropAnyNull(CreateStruct(Seq('a.string.at(0), 'b.string.at(1)))) - val a = create_row("a", "q") - val nullStr: String = null - checkEvaluation(drop, a, a) - checkEvaluation(drop, null, create_row("b", nullStr)) - checkEvaluation(drop, null, create_row(nullStr, nullStr)) - - val row = 'r.struct( - StructField("a", StringType, false), - StructField("b", StringType, true)).at(0) - checkEvaluation(DropAnyNull(row), null, create_row(null)) - } } |