aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorHerman van Hovell <hvanhovell@questtec.nl>2015-11-29 14:13:11 -0800
committerYin Huai <yhuai@databricks.com>2015-11-29 14:13:11 -0800
commit3d28081e53698ed77e93c04299957c02bcaba9bf (patch)
tree8c4c791b93a06e975a31b140784c38bc6980b303 /sql/catalyst
parentcc7a1bc9370b163f51230e5ca4be612d133a5086 (diff)
downloadspark-3d28081e53698ed77e93c04299957c02bcaba9bf.tar.gz
spark-3d28081e53698ed77e93c04299957c02bcaba9bf.tar.bz2
spark-3d28081e53698ed77e93c04299957c02bcaba9bf.zip
[SPARK-12024][SQL] More efficient multi-column counting.
In https://github.com/apache/spark/pull/9409 we enabled multi-column counting. The approach taken in that PR introduces a bit of overhead by first creating a row only to check if all of the columns are non-null. This PR fixes that technical debt. Count now takes multiple columns as its input. In order to make this work I have also added support for multiple columns in the single distinct code path. cc yhuai Author: Herman van Hovell <hvanhovell@questtec.nl> Closes #10015 from hvanhovell/SPARK-12024.
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala21
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala27
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala14
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala14
4 files changed, 12 insertions, 64 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
index 09a1da9200..441f52ab5c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
@@ -21,8 +21,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
-case class Count(child: Expression) extends DeclarativeAggregate {
- override def children: Seq[Expression] = child :: Nil
+case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
override def nullable: Boolean = false
@@ -30,7 +29,7 @@ case class Count(child: Expression) extends DeclarativeAggregate {
override def dataType: DataType = LongType
// Expected input data type.
- override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+ override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(AnyDataType)
private lazy val count = AttributeReference("count", LongType)()
@@ -41,7 +40,7 @@ case class Count(child: Expression) extends DeclarativeAggregate {
)
override lazy val updateExpressions = Seq(
- /* count = */ If(IsNull(child), count, count + 1L)
+ /* count = */ If(children.map(IsNull).reduce(Or), count, count + 1L)
)
override lazy val mergeExpressions = Seq(
@@ -54,17 +53,5 @@ case class Count(child: Expression) extends DeclarativeAggregate {
}
object Count {
- def apply(children: Seq[Expression]): Count = {
- // This is used to deal with COUNT DISTINCT. When we have multiple
- // children (COUNT(DISTINCT col1, col2, ...)), we wrap them in a STRUCT (i.e. a Row).
- // Also, the semantic of COUNT(DISTINCT col1, col2, ...) is that if there is any
- // null in the arguments, we will not count that row. So, we use DropAnyNull at here
- // to return a null when any field of the created STRUCT is null.
- val child = if (children.size > 1) {
- DropAnyNull(CreateStruct(children))
- } else {
- children.head
- }
- Count(child)
- }
+ def apply(child: Expression): Count = Count(child :: Nil)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 694a2a7c54..40b1eec63e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -426,30 +426,3 @@ case class Greatest(children: Seq[Expression]) extends Expression {
}
}
-/** Operator that drops a row when it contains any nulls. */
-case class DropAnyNull(child: Expression) extends UnaryExpression with ExpectsInputTypes {
- override def nullable: Boolean = true
- override def dataType: DataType = child.dataType
- override def inputTypes: Seq[AbstractDataType] = Seq(StructType)
-
- protected override def nullSafeEval(input: Any): InternalRow = {
- val row = input.asInstanceOf[InternalRow]
- if (row.anyNull) {
- null
- } else {
- row
- }
- }
-
- override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- nullSafeCodeGen(ctx, ev, eval => {
- s"""
- if ($eval.anyNull()) {
- ${ev.isNull} = true;
- } else {
- ${ev.value} = $eval;
- }
- """
- })
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 2901d8f2ef..06d14fcf8b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -362,9 +362,14 @@ object LikeSimplification extends Rule[LogicalPlan] {
* Null value propagation from bottom to top of the expression tree.
*/
object NullPropagation extends Rule[LogicalPlan] {
+ def nonNullLiteral(e: Expression): Boolean = e match {
+ case Literal(null, _) => false
+ case _ => true
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
- case e @ AggregateExpression(Count(Literal(null, _)), _, _) =>
+ case e @ AggregateExpression(Count(exprs), _, _) if !exprs.exists(nonNullLiteral) =>
Cast(Literal(0L), e.dataType)
case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
@@ -377,16 +382,13 @@ object NullPropagation extends Rule[LogicalPlan] {
Literal.create(null, e.dataType)
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
- case e @ AggregateExpression(Count(expr), mode, false) if !expr.nullable =>
+ case e @ AggregateExpression(Count(exprs), mode, false) if !exprs.exists(_.nullable) =>
// This rule should be only triggered when isDistinct field is false.
AggregateExpression(Count(Literal(1)), mode, isDistinct = false)
// For Coalesce, remove null literals.
case e @ Coalesce(children) =>
- val newChildren = children.filter {
- case Literal(null, _) => false
- case _ => true
- }
+ val newChildren = children.filter(nonNullLiteral)
if (newChildren.length == 0) {
Literal.create(null, e.dataType)
} else if (newChildren.length == 1) {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
index c1e3c17b87..0df673bb9f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
@@ -231,18 +231,4 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2)
}
}
-
- test("function dropAnyNull") {
- val drop = DropAnyNull(CreateStruct(Seq('a.string.at(0), 'b.string.at(1))))
- val a = create_row("a", "q")
- val nullStr: String = null
- checkEvaluation(drop, a, a)
- checkEvaluation(drop, null, create_row("b", nullStr))
- checkEvaluation(drop, null, create_row(nullStr, nullStr))
-
- val row = 'r.struct(
- StructField("a", StringType, false),
- StructField("b", StringType, true)).at(0)
- checkEvaluation(DropAnyNull(row), null, create_row(null))
- }
}