author    Herman van Hovell <hvanhovell@questtec.nl>    2015-11-08 11:06:10 -0800
committer Yin Huai <yhuai@databricks.com>    2015-11-08 11:06:10 -0800
commit    30c8ba71a76788cbc6916bc1ba6bc8522925fc2b (patch)
tree      851dbdcce7d78bbf6fd4c948dd4407642fea63cc
parent    5c4e6d7ec9157c02494a382dfb49e7fbde3be222 (diff)
[SPARK-11451][SQL] Support single distinct count on multiple columns.
This PR adds support for multiple columns in a single count distinct aggregate to the new aggregation path.

cc yhuai

Author: Herman van Hovell <hvanhovell@questtec.nl>

Closes #9409 from hvanhovell/SPARK-11451.
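For illustration, a minimal sketch of the behavior this patch enables. This is not part of the patch; it assumes a Spark 1.5-era SQLContext with `import sqlContext.implicits._` in scope, and the DataFrame and column names are hypothetical.

    import org.apache.spark.sql.functions.countDistinct

    val df = Seq(
      ("a", "b"),
      ("a", "b"),
      ("x", null.asInstanceOf[String])).toDF("k1", "k2")

    // Counts distinct (k1, k2) pairs in which no column is null: only
    // ("a", "b") qualifies, so this returns 1, matching SQL
    // COUNT(DISTINCT k1, k2) semantics.
    df.agg(countDistinct('k1, 'k2)).show()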
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala | 44
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala | 30
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala | 3
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala | 14
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | 25
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala | 37
6 files changed, 127 insertions(+), 26 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
index ac23f72782..9b22ce2619 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
@@ -22,26 +22,27 @@ import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Expand, Aggregate, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.types.{IntegerType, StructType, MapType, ArrayType}
+import org.apache.spark.sql.types._
/**
* Utility functions used by the query planner to convert our plan to the new aggregation code path.
*/
object Utils {
- // Right now, we do not support complex types in the grouping key schema.
- private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = {
- val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists {
- case array: ArrayType => true
- case map: MapType => true
- case struct: StructType => true
- case _ => false
- }
- !hasComplexTypes
+ // Check if the DataType given cannot be part of a group by clause.
+ private def isUnGroupable(dt: DataType): Boolean = dt match {
+ case _: ArrayType | _: MapType => true
+ case s: StructType => s.fields.exists(f => isUnGroupable(f.dataType))
+ case _ => false
}
+ // Right now, we do not support complex types in the grouping key schema.
+ private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean =
+ !aggregate.groupingExpressions.exists(e => isUnGroupable(e.dataType))
+
private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
case p: Aggregate if supportsGroupingKeySchema(p) =>
+
val converted = MultipleDistinctRewriter.rewrite(p.transformExpressionsDown {
case expressions.Average(child) =>
aggregate.AggregateExpression2(
@@ -55,10 +56,14 @@ object Utils {
mode = aggregate.Complete,
isDistinct = false)
- // We do not support multiple COUNT DISTINCT columns for now.
- case expressions.CountDistinct(children) if children.length == 1 =>
+ case expressions.CountDistinct(children) =>
+ val child = if (children.size > 1) {
+ DropAnyNull(CreateStruct(children))
+ } else {
+ children.head
+ }
aggregate.AggregateExpression2(
- aggregateFunction = aggregate.Count(children.head),
+ aggregateFunction = aggregate.Count(child),
mode = aggregate.Complete,
isDistinct = true)
@@ -320,7 +325,7 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
val gid = new AttributeReference("gid", IntegerType, false)()
val groupByMap = a.groupingExpressions.collect {
case ne: NamedExpression => ne -> ne.toAttribute
- case e => e -> new AttributeReference(e.prettyName, e.dataType, e.nullable)()
+ case e => e -> new AttributeReference(e.prettyString, e.dataType, e.nullable)()
}
val groupByAttrs = groupByMap.map(_._2)
@@ -365,14 +370,15 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
// Setup expand for the 'regular' aggregate expressions.
val regularAggExprs = aggExpressions.filter(!_.isDistinct)
val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
- val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair).toMap
+ val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair)
// Setup aggregates for 'regular' aggregate expressions.
val regularGroupId = Literal(0)
+ val regularAggChildAttrLookup = regularAggChildAttrMap.toMap
val regularAggOperatorMap = regularAggExprs.map { e =>
// Perform the actual aggregation in the initial aggregate.
- val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrMap)
- val operator = Alias(e.copy(aggregateFunction = af), e.toString)()
+ val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup)
+ val operator = Alias(e.copy(aggregateFunction = af), e.prettyString)()
// Select the result of the first aggregate in the last aggregate.
val result = AggregateExpression2(
@@ -416,7 +422,7 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
// Construct the expand operator.
val expand = Expand(
regularAggProjection ++ distinctAggProjections,
- groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.values.toSeq,
+ groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2),
a.child)
// Construct the first aggregate operator. This de-duplicates all the children of
@@ -457,5 +463,5 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
// NamedExpression. This is done to prevent collisions between distinct and regular aggregate
// children; in this case attribute reuse causes the input of the regular aggregate to be bound to
// the (nulled out) input of the distinct aggregate.
- e -> new AttributeReference(e.prettyName, e.dataType, true)()
+ e -> new AttributeReference(e.prettyString, e.dataType, true)()
}
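The key rewrite above turns a multi-column COUNT(DISTINCT c1, ..., cn) into a single-column distinct Count over DropAnyNull(CreateStruct(children)), so a row only contributes when all of its columns are non-null. A self-contained plain-Scala sketch of those semantics, using ordinary collections rather than Catalyst types:

    // COUNT(DISTINCT c1, ..., cn): count distinct tuples containing no null.
    def countDistinctMulti(rows: Seq[Seq[Any]]): Int =
      rows.filter(row => !row.contains(null)).distinct.size

    // Mirrors the DataFrameAggregateSuite fixture below: ("x", "q", null)
    // drops out, leaving three distinct tuples.
    val rows = Seq(
      Seq("a", "b", "c"),
      Seq("a", "b", "c"),
      Seq("a", "b", "d"),
      Seq("x", "y", "z"),
      Seq("x", "q", null))
    assert(countDistinctMulti(rows) == 3)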
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index d532629984..0d4af43978 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.TypeUtils
-import org.apache.spark.sql.types.{NullType, BooleanType, DataType}
+import org.apache.spark.sql.types._
case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
@@ -419,3 +419,31 @@ case class Greatest(children: Seq[Expression]) extends Expression {
"""
}
}
+
+/** Operator that drops a row when it contains any nulls. */
+case class DropAnyNull(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+ override def nullable: Boolean = true
+ override def dataType: DataType = child.dataType
+ override def inputTypes: Seq[AbstractDataType] = Seq(StructType)
+
+ protected override def nullSafeEval(input: Any): InternalRow = {
+ val row = input.asInstanceOf[InternalRow]
+ if (row.anyNull) {
+ null
+ } else {
+ row
+ }
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ nullSafeCodeGen(ctx, ev, eval => {
+ s"""
+ if ($eval.anyNull()) {
+ ${ev.isNull} = true;
+ } else {
+ ${ev.value} = $eval;
+ }
+ """
+ })
+ }
+}
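DropAnyNull, added above, returns its struct-typed input unchanged unless any field is null, in which case it returns null; Count's existing null handling then skips that row. A behavioral sketch in plain Scala (Catalyst's version operates on InternalRow and also has a codegen path):

    // null if any field is null, otherwise the row itself.
    def dropAnyNull(row: Seq[Any]): Seq[Any] =
      if (row.contains(null)) null else row

    assert(dropAnyNull(Seq("a", "q")) == Seq("a", "q"))
    assert(dropAnyNull(Seq("b", null)) == null)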
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index fb963e2f8f..09aac00a45 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -306,6 +306,9 @@ case class Expand(
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
+ override def references: AttributeSet =
+ AttributeSet(projections.flatten.flatMap(_.references))
+
override def statistics: Statistics = {
// TODO shouldn't we factor in the size of the projection versus the size of the backing child
// row?
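The `references` override above matters because Expand consumes child attributes through its `projections`, which the default `references` computation (derived from the node's expressions) does not cover; without it, column pruning could drop child columns that only the projections read. Row-wise, Expand emits one output row per projection for every input row, as in this simplified sketch (generic types, not Catalyst's actual definitions):

    // Every input row is expanded into one output row per projection.
    def expand[Row](projections: Seq[Row => Seq[Any]], input: Seq[Row]): Seq[Seq[Any]] =
      for (row <- input; proj <- projections) yield proj(row)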
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
index 0df673bb9f..c1e3c17b87 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
@@ -231,4 +231,18 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2)
}
}
+
+ test("function dropAnyNull") {
+ val drop = DropAnyNull(CreateStruct(Seq('a.string.at(0), 'b.string.at(1))))
+ val a = create_row("a", "q")
+ val nullStr: String = null
+ checkEvaluation(drop, a, a)
+ checkEvaluation(drop, null, create_row("b", nullStr))
+ checkEvaluation(drop, null, create_row(nullStr, nullStr))
+
+ val row = 'r.struct(
+ StructField("a", StringType, false),
+ StructField("b", StringType, true)).at(0)
+ checkEvaluation(DropAnyNull(row), null, create_row(null))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 2e679e7bc4..eb1ee266c5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -162,6 +162,31 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
)
}
+ test("multiple column distinct count") {
+ val df1 = Seq(
+ ("a", "b", "c"),
+ ("a", "b", "c"),
+ ("a", "b", "d"),
+ ("x", "y", "z"),
+ ("x", "q", null.asInstanceOf[String]))
+ .toDF("key1", "key2", "key3")
+
+ checkAnswer(
+ df1.agg(countDistinct('key1, 'key2)),
+ Row(3)
+ )
+
+ checkAnswer(
+ df1.agg(countDistinct('key1, 'key2, 'key3)),
+ Row(3)
+ )
+
+ checkAnswer(
+ df1.groupBy('key1).agg(countDistinct('key2, 'key3)),
+ Seq(Row("a", 2), Row("x", 1))
+ )
+ }
+
test("zero count") {
val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
checkAnswer(
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 7f6fe33923..ea36c132bb 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -516,21 +516,46 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(3, 4, 4, 3, null) :: Nil)
}
- test("multiple distinct column sets") {
+ test("single distinct multiple columns set") {
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | key,
+ | count(distinct value1, value2)
+ |FROM agg2
+ |GROUP BY key
+ """.stripMargin),
+ Row(null, 3) ::
+ Row(1, 3) ::
+ Row(2, 1) ::
+ Row(3, 0) :: Nil)
+ }
+
+ test("multiple distinct multiple columns sets") {
checkAnswer(
sqlContext.sql(
"""
|SELECT
| key,
| count(distinct value1),
- | count(distinct value2)
+ | sum(distinct value1),
+ | count(distinct value2),
+ | sum(distinct value2),
+ | count(distinct value1, value2),
+ | count(value1),
+ | sum(value1),
+ | count(value2),
+ | sum(value2),
+ | count(*),
+ | count(1)
|FROM agg2
|GROUP BY key
""".stripMargin),
- Row(null, 3, 3) ::
- Row(1, 2, 3) ::
- Row(2, 2, 1) ::
- Row(3, 0, 1) :: Nil)
+ Row(null, 3, 30, 3, 60, 3, 3, 30, 3, 60, 4, 4) ::
+ Row(1, 2, 40, 3, -10, 3, 3, 70, 3, -10, 3, 3) ::
+ Row(2, 2, 0, 1, 1, 1, 3, 1, 3, 3, 4, 4) ::
+ Row(3, 0, null, 1, 3, 0, 0, null, 1, 3, 2, 2) :: Nil)
}
test("test count") {