diff options
author | Davies Liu <davies@databricks.com> | 2016-01-20 10:02:40 -0800 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2016-01-20 10:02:40 -0800 |
commit | 8e4f894e986ccd943df9ddf55fc853eb0558886f (patch) | |
tree | 3c8f46759f12afbfe2a5253a75e136074f5febc1 /sql/core | |
parent | 753b1945115245800898959e3ab249a94a1935e9 (diff) | |
download | spark-8e4f894e986ccd943df9ddf55fc853eb0558886f.tar.gz spark-8e4f894e986ccd943df9ddf55fc853eb0558886f.tar.bz2 spark-8e4f894e986ccd943df9ddf55fc853eb0558886f.zip |
[SPARK-12881] [SQL] subexpression elimination in mutable projection
Author: Davies Liu <davies@databricks.com>
Closes #10814 from davies/mutable_subexpr.
Diffstat (limited to 'sql/core')
5 files changed, 22 insertions, 6 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 75101ea0fc..b19b772409 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -196,10 +196,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ private[this] def isTesting: Boolean = sys.props.contains("spark.testing") protected def newMutableProjection( - expressions: Seq[Expression], inputSchema: Seq[Attribute]): () => MutableProjection = { + expressions: Seq[Expression], + inputSchema: Seq[Attribute], + useSubexprElimination: Boolean = false): () => MutableProjection = { log.debug(s"Creating MutableProj: $expressions, inputSchema: $inputSchema") try { - GenerateMutableProjection.generate(expressions, inputSchema) + GenerateMutableProjection.generate(expressions, inputSchema, useSubexprElimination) } catch { case e: Exception => if (isTesting) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 168b5ab031..26a7340f1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -194,7 +194,11 @@ case class Window( val functions = functionSeq.toArray // Construct an aggregate processor if we need one. 
- def processor = AggregateProcessor(functions, ordinal, child.output, newMutableProjection) + def processor = AggregateProcessor( + functions, + ordinal, + child.output, + (expressions, schema) => newMutableProjection(expressions, schema)) // Create the factory val factory = key match { @@ -206,7 +210,7 @@ case class Window( ordinal, functions, child.output, - newMutableProjection, + (expressions, schema) => newMutableProjection(expressions, schema), offset) // Growing Frame. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index 1d56592c40..06a3991459 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -87,7 +87,8 @@ case class SortBasedAggregate( aggregateAttributes, initialInputBufferOffset, resultExpressions, - newMutableProjection, + (expressions, inputSchema) => + newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled), numInputRows, numOutputRows) if (!hasInput && groupingExpressions.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index a9cf04388d..8dcbab4c8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -94,7 +94,8 @@ case class TungstenAggregate( aggregateAttributes, initialInputBufferOffset, resultExpressions, - newMutableProjection, + (expressions, inputSchema) => + newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled), child.output, iter, testFallbackStartsAt, diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index d7f182352b..b159346bed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.parser.ParserConf import org.apache.spark.sql.execution.{aggregate, SparkQl} import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin} +import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} import org.apache.spark.sql.test.SQLTestData._ @@ -1968,6 +1969,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { verifyCallCount( df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) + val testUdf = functions.udf((x: Int) => { + countAcc.++=(1) + x + }) + verifyCallCount( + df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) + // Would be nice if semantic equals for `+` understood commutative verifyCallCount( df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) |