aboutsummaryrefslogtreecommitdiff
path: root/sql/core
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2016-01-20 10:02:40 -0800
committerDavies Liu <davies.liu@gmail.com>2016-01-20 10:02:40 -0800
commit8e4f894e986ccd943df9ddf55fc853eb0558886f (patch)
tree3c8f46759f12afbfe2a5253a75e136074f5febc1 /sql/core
parent753b1945115245800898959e3ab249a94a1935e9 (diff)
downloadspark-8e4f894e986ccd943df9ddf55fc853eb0558886f.tar.gz
spark-8e4f894e986ccd943df9ddf55fc853eb0558886f.tar.bz2
spark-8e4f894e986ccd943df9ddf55fc853eb0558886f.zip
[SPARK-12881] [SQL] subexpress elimination in mutable projection
Author: Davies Liu <davies@databricks.com> Closes #10814 from davies/mutable_subexpr.
Diffstat (limited to 'sql/core')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala6
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala8
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala3
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala3
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala8
5 files changed, 22 insertions, 6 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 75101ea0fc..b19b772409 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -196,10 +196,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
private[this] def isTesting: Boolean = sys.props.contains("spark.testing")
protected def newMutableProjection(
- expressions: Seq[Expression], inputSchema: Seq[Attribute]): () => MutableProjection = {
+ expressions: Seq[Expression],
+ inputSchema: Seq[Attribute],
+ useSubexprElimination: Boolean = false): () => MutableProjection = {
log.debug(s"Creating MutableProj: $expressions, inputSchema: $inputSchema")
try {
- GenerateMutableProjection.generate(expressions, inputSchema)
+ GenerateMutableProjection.generate(expressions, inputSchema, useSubexprElimination)
} catch {
case e: Exception =>
if (isTesting) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index 168b5ab031..26a7340f1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -194,7 +194,11 @@ case class Window(
val functions = functionSeq.toArray
// Construct an aggregate processor if we need one.
- def processor = AggregateProcessor(functions, ordinal, child.output, newMutableProjection)
+ def processor = AggregateProcessor(
+ functions,
+ ordinal,
+ child.output,
+ (expressions, schema) => newMutableProjection(expressions, schema))
// Create the factory
val factory = key match {
@@ -206,7 +210,7 @@ case class Window(
ordinal,
functions,
child.output,
- newMutableProjection,
+ (expressions, schema) => newMutableProjection(expressions, schema),
offset)
// Growing Frame.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
index 1d56592c40..06a3991459 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
@@ -87,7 +87,8 @@ case class SortBasedAggregate(
aggregateAttributes,
initialInputBufferOffset,
resultExpressions,
- newMutableProjection,
+ (expressions, inputSchema) =>
+ newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled),
numInputRows,
numOutputRows)
if (!hasInput && groupingExpressions.isEmpty) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index a9cf04388d..8dcbab4c8c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -94,7 +94,8 @@ case class TungstenAggregate(
aggregateAttributes,
initialInputBufferOffset,
resultExpressions,
- newMutableProjection,
+ (expressions, inputSchema) =>
+ newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled),
child.output,
iter,
testFallbackStartsAt,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index d7f182352b..b159346bed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.parser.ParserConf
import org.apache.spark.sql.execution.{aggregate, SparkQl}
import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin}
+import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext}
import org.apache.spark.sql.test.SQLTestData._
@@ -1968,6 +1969,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
verifyCallCount(
df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2)
+ val testUdf = functions.udf((x: Int) => {
+ countAcc.++=(1)
+ x
+ })
+ verifyCallCount(
+ df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1)
+
// Would be nice if semantic equals for `+` understood commutative
verifyCallCount(
df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2)