author    Davies Liu <davies@databricks.com>    2016-01-29 01:59:59 -0800
committer Reynold Xin <rxin@databricks.com>    2016-01-29 01:59:59 -0800
commit    55561e7693dd2a5bf3c7f8026c725421801fd0ec (patch)
tree      af8a0fbac09a1168e1a4088428a68ed7a5ebd66c
parent    8d3cc3de7d116190911e7943ef3233fe3b7db1bf (diff)
download  spark-55561e7693dd2a5bf3c7f8026c725421801fd0ec.tar.gz
          spark-55561e7693dd2a5bf3c7f8026c725421801fd0ec.tar.bz2
          spark-55561e7693dd2a5bf3c7f8026c725421801fd0ec.zip
[SPARK-13031][SQL] cleanup codegen and improve test coverage
1. Enable whole-stage codegen during tests even when only one operator supports it.
2. Split doProduce() into two APIs: upstream() and doProduce().
3. Generate a prefix for the fresh names of each operator.
4. Pass UnsafeRow to the parent directly (avoid calling getters and creating the UnsafeRow again).
5. Fix bugs and tests.

This PR re-opens #10944 and fixes the bug.

Author: Davies Liu <davies@databricks.com>

Closes #10977 from davies/gen_refactor.
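For orientation, a condensed sketch of the reshaped CodegenSupport contract this patch introduces (signatures follow the diff below; the comments are summaries and the trait is abbreviated, not complete):

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}

trait CodegenSupport extends SparkPlan {
  // The RDD that feeds rows into the generated code; previously this was half of
  // the return value of doProduce(), which now returns only the Java source.
  def upstream(): RDD[InternalRow]

  // Sets ctx.freshNamePrefix from nodeName, then delegates to doProduce().
  def produce(ctx: CodegenContext, parent: CodegenSupport): String
  protected def doProduce(ctx: CodegenContext): String

  // A child may hand up either evaluated columns (input) or an existing UnsafeRow
  // (row); the parent's consumeChild() handles both cases before calling doConsume().
  def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String
  protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String
}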
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala  13
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala  188
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala  98
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala  96
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala  7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala  103
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala  34
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala  10
11 files changed, 350 insertions, 205 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 2747c315ad..e6704cf8bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -145,13 +145,22 @@ class CodegenContext {
private val curId = new java.util.concurrent.atomic.AtomicInteger()
/**
+ * A prefix used to generate fresh names.
+ */
+ var freshNamePrefix = ""
+
+ /**
* Returns a term name that is unique within this instance of a `CodeGenerator`.
*
* (Since we aren't in a macro context we do not seem to have access to the built in `freshName`
* function.)
*/
- def freshName(prefix: String): String = {
- s"$prefix${curId.getAndIncrement}"
+ def freshName(name: String): String = {
+ if (freshNamePrefix == "") {
+ s"$name${curId.getAndIncrement}"
+ } else {
+ s"${freshNamePrefix}_$name${curId.getAndIncrement}"
+ }
}
/**
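The practical effect of the prefix, with hypothetical values (the identifiers and counter values here are illustrative, not taken from the patch):

  // before any operator sets a prefix (freshNamePrefix == ""):
  ctx.freshName("value")   // => "value7"
  // after produce() sets the prefix from an operator's nodeName, e.g. "Filter":
  ctx.freshName("value")   // => "Filter_value8"

so every identifier in the generated source reveals which operator emitted it.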
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index d9fe76133c..ec31db19b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -93,7 +93,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
// Can't call setNullAt on DecimalType, because we need to keep the offset
s"""
if (this.isNull_$i) {
- ${ctx.setColumn("mutableRow", e.dataType, i, null)};
+ ${ctx.setColumn("mutableRow", e.dataType, i, "null")};
} else {
${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 57f4945de9..ef81ba60f0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -22,9 +22,11 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, LeafExpression}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.util.Utils
/**
* An interface for those physical operators that support codegen.
@@ -42,10 +44,16 @@ trait CodegenSupport extends SparkPlan {
private var parent: CodegenSupport = null
/**
- * Returns an input RDD of InternalRow and Java source code to process them.
+ * Returns the RDD of InternalRow which generates the input rows.
*/
- def produce(ctx: CodegenContext, parent: CodegenSupport): (RDD[InternalRow], String) = {
+ def upstream(): RDD[InternalRow]
+
+ /**
+ * Returns Java source code to process the rows from upstream.
+ */
+ def produce(ctx: CodegenContext, parent: CodegenSupport): String = {
this.parent = parent
+ ctx.freshNamePrefix = nodeName
doProduce(ctx)
}
@@ -66,16 +74,41 @@ trait CodegenSupport extends SparkPlan {
* # call consume(), which will call parent.doConsume()
* }
*/
- protected def doProduce(ctx: CodegenContext): (RDD[InternalRow], String)
+ protected def doProduce(ctx: CodegenContext): String
/**
- * Consume the columns generated from current SparkPlan, call it's parent or create an iterator.
+ * Consume the columns generated from the current SparkPlan and call its parent.
*/
- protected def consume(ctx: CodegenContext, columns: Seq[ExprCode]): String = {
- assert(columns.length == output.length)
- parent.doConsume(ctx, this, columns)
+ def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = {
+ if (input != null) {
+ assert(input.length == output.length)
+ }
+ parent.consumeChild(ctx, this, input, row)
}
+ /**
+ * Consume the columns generated from its child, then call doConsume() or emit the rows.
+ */
+ def consumeChild(
+ ctx: CodegenContext,
+ child: SparkPlan,
+ input: Seq[ExprCode],
+ row: String = null): String = {
+ ctx.freshNamePrefix = nodeName
+ if (row != null) {
+ ctx.currentVars = null
+ ctx.INPUT_ROW = row
+ val evals = child.output.zipWithIndex.map { case (attr, i) =>
+ BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
+ }
+ s"""
+ | ${evals.map(_.code).mkString("\n")}
+ | ${doConsume(ctx, evals)}
+ """.stripMargin
+ } else {
+ doConsume(ctx, input)
+ }
+ }
/**
* Generate the Java source code to process the rows from child SparkPlan.
@@ -89,7 +122,9 @@ trait CodegenSupport extends SparkPlan {
* # call consume(), which will call parent.doConsume()
* }
*/
- def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String
+ protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ throw new UnsupportedOperationException
+ }
}
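A hedged usage sketch of the new row-based path through consume()/consumeChild() (the operator and the variable name are hypothetical; only the method contract comes from this hunk):

  // inside the codegen of some hypothetical operator that already holds an UnsafeRow
  val row = ctx.freshName("unsafeRow")   // fresh name, prefixed with this operator's nodeName
  // ... generated Java that assigns an UnsafeRow to `row` ...
  consume(ctx, null, row)

The parent's consumeChild() sees row != null and either forwards the row unchanged or binds BoundReferences against it before calling its own doConsume(), so the row is not re-evaluated column by column and projected into a second UnsafeRow.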
@@ -102,31 +137,36 @@ trait CodegenSupport extends SparkPlan {
case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
override def output: Seq[Attribute] = child.output
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+ override def doPrepare(): Unit = {
+ child.prepare()
+ }
- override def supportCodegen: Boolean = true
+ override def doExecute(): RDD[InternalRow] = {
+ child.execute()
+ }
- override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+ override def supportCodegen: Boolean = false
+
+ override def upstream(): RDD[InternalRow] = {
+ child.execute()
+ }
+
+ override def doProduce(ctx: CodegenContext): String = {
val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
val row = ctx.freshName("row")
ctx.INPUT_ROW = row
ctx.currentVars = null
val columns = exprs.map(_.gen(ctx))
- val code = s"""
- | while (input.hasNext()) {
+ s"""
+ | while (input.hasNext()) {
| InternalRow $row = (InternalRow) input.next();
| ${columns.map(_.code).mkString("\n")}
| ${consume(ctx, columns)}
| }
""".stripMargin
- (child.execute(), code)
- }
-
- def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
- throw new UnsupportedOperationException
- }
-
- override def doExecute(): RDD[InternalRow] = {
- throw new UnsupportedOperationException
}
override def simpleString: String = "INPUT"
@@ -143,16 +183,20 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
*
* -> execute()
* |
- * doExecute() --------> produce()
+ * doExecute() ---------> upstream() -------> upstream() ------> execute()
+ * |
+ * -----------------> produce()
* |
* doProduce() -------> produce()
* |
- * doProduce() ---> execute()
+ * doProduce()
* |
* consume()
- * doConsume() ------------|
+ * consumeChild() <-----------|
* |
- * doConsume() <----- consume()
+ * doConsume()
+ * |
+ * consumeChild() <----- consume()
*
* SparkPlan A should override doProduce() and doConsume().
*
@@ -162,37 +206,48 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
extends SparkPlan with CodegenSupport {
+ override def supportCodegen: Boolean = false
+
override def output: Seq[Attribute] = plan.output
+ override def outputPartitioning: Partitioning = plan.outputPartitioning
+ override def outputOrdering: Seq[SortOrder] = plan.outputOrdering
+
+ override def doPrepare(): Unit = {
+ plan.prepare()
+ }
override def doExecute(): RDD[InternalRow] = {
val ctx = new CodegenContext
- val (rdd, code) = plan.produce(ctx, this)
+ val code = plan.produce(ctx, this)
val references = ctx.references.toArray
val source = s"""
public Object generate(Object[] references) {
- return new GeneratedIterator(references);
+ return new GeneratedIterator(references);
}
class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
- private Object[] references;
- ${ctx.declareMutableStates()}
+ private Object[] references;
+ ${ctx.declareMutableStates()}
+ ${ctx.declareAddedFunctions()}
- public GeneratedIterator(Object[] references) {
+ public GeneratedIterator(Object[] references) {
this.references = references;
${ctx.initMutableStates()}
- }
+ }
- protected void processNext() {
+ protected void processNext() throws java.io.IOException {
$code
- }
+ }
}
- """
+ """
+
// try to compile, helpful for debug
// println(s"${CodeFormatter.format(source)}")
CodeGenerator.compile(source)
- rdd.mapPartitions { iter =>
+ plan.upstream().mapPartitions { iter =>
+
val clazz = CodeGenerator.compile(source)
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
buffer.setInput(iter)
@@ -203,29 +258,47 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
}
}
- override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+ override def upstream(): RDD[InternalRow] = {
throw new UnsupportedOperationException
}
- override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
- if (input.nonEmpty) {
- val colExprs = output.zipWithIndex.map { case (attr, i) =>
- BoundReference(i, attr.dataType, attr.nullable)
- }
- // generate the code to create a UnsafeRow
- ctx.currentVars = input
- val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
- s"""
- | ${code.code.trim}
- | currentRow = ${code.value};
- | return;
- """.stripMargin
- } else {
- // There is no columns
+ override def doProduce(ctx: CodegenContext): String = {
+ throw new UnsupportedOperationException
+ }
+
+ override def consumeChild(
+ ctx: CodegenContext,
+ child: SparkPlan,
+ input: Seq[ExprCode],
+ row: String = null): String = {
+
+ if (row != null) {
+ // There is an UnsafeRow already
s"""
- | currentRow = unsafeRow;
+ | currentRow = $row;
| return;
""".stripMargin
+ } else {
+ assert(input != null)
+ if (input.nonEmpty) {
+ val colExprs = output.zipWithIndex.map { case (attr, i) =>
+ BoundReference(i, attr.dataType, attr.nullable)
+ }
+ // generate the code to create a UnsafeRow
+ ctx.currentVars = input
+ val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
+ s"""
+ | ${code.code.trim}
+ | currentRow = ${code.value};
+ | return;
+ """.stripMargin
+ } else {
+ // There are no columns
+ s"""
+ | currentRow = unsafeRow;
+ | return;
+ """.stripMargin
+ }
}
}
@@ -246,7 +319,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
builder.append(simpleString)
builder.append("\n")
- plan.generateTreeString(depth + 1, lastChildren :+children.isEmpty :+ true, builder)
+ plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder)
if (children.nonEmpty) {
children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
children.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
@@ -286,13 +359,14 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
case plan: CodegenSupport if supportCodegen(plan) &&
// Whole stage codegen is only useful when there are at least two levels of operators that
// support it (save at least one projection/iterator).
- plan.children.exists(supportCodegen) =>
+ (Utils.isTesting || plan.children.exists(supportCodegen)) =>
var inputs = ArrayBuffer[SparkPlan]()
val combined = plan.transform {
case p if !supportCodegen(p) =>
- inputs += p
- InputAdapter(p)
+ val input = apply(p) // collapse them recursively
+ inputs += input
+ InputAdapter(input)
}.asInstanceOf[CodegenSupport]
WholeStageCodegen(combined, inputs)
}
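To make the recursive collapse at the end of CollapseCodegenStages concrete, a hedged before/after sketch in the style of the call diagram above (NonCodegenOp stands for any hypothetical operator without CodegenSupport):

   before:                     after CollapseCodegenStages:

   Project                     WholeStageCodegen
     Filter                      Project
       NonCodegenOp                Filter
         Project                     InputAdapter
           Filter                       NonCodegenOp
             Range                        WholeStageCodegen
                                            Project
                                              Filter
                                                Range

Because apply() is re-invoked on the non-codegen subtree before it is wrapped in InputAdapter, the codegen-capable stage underneath it is still fused on its own.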
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index 0c74df0aa5..38da82c47c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -238,7 +238,7 @@ abstract class AggregationIterator(
resultProjection(joinedRow(currentGroupingKey, currentBuffer))
}
} else {
- // Grouping-only: we only output values of grouping expressions.
+ // Grouping-only: we only output values based on grouping expressions.
val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes)
(currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => {
resultProjection(currentGroupingKey)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 23e54f344d..ff2f38bfd9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -117,9 +117,7 @@ case class TungstenAggregate(
override def supportCodegen: Boolean = {
groupingExpressions.isEmpty &&
// ImperativeAggregate is not supported right now
- !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) &&
- // final aggregation only have one row, do not need to codegen
- !aggregateExpressions.exists(e => e.mode == Final || e.mode == Complete)
+ !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
}
// The variables used as aggregation buffer
@@ -127,7 +125,11 @@ case class TungstenAggregate(
private val modes = aggregateExpressions.map(_.mode).distinct
- protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+ override def upstream(): RDD[InternalRow] = {
+ child.asInstanceOf[CodegenSupport].upstream()
+ }
+
+ protected override def doProduce(ctx: CodegenContext): String = {
val initAgg = ctx.freshName("initAgg")
ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
@@ -137,60 +139,96 @@ case class TungstenAggregate(
bufVars = initExpr.map { e =>
val isNull = ctx.freshName("bufIsNull")
val value = ctx.freshName("bufValue")
+ ctx.addMutableState("boolean", isNull, "")
+ ctx.addMutableState(ctx.javaType(e.dataType), value, "")
// The initial expression should not access any column
val ev = e.gen(ctx)
val initVars = s"""
- | boolean $isNull = ${ev.isNull};
- | ${ctx.javaType(e.dataType)} $value = ${ev.value};
+ | $isNull = ${ev.isNull};
+ | $value = ${ev.value};
""".stripMargin
ExprCode(ev.code + initVars, isNull, value)
}
- val (rdd, childSource) = child.asInstanceOf[CodegenSupport].produce(ctx, this)
- val source =
+ // generate variables for output
+ val bufferAttrs = functions.flatMap(_.aggBufferAttributes)
+ val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
+ // evaluate aggregate results
+ ctx.currentVars = bufVars
+ val aggResults = functions.map(_.evaluateExpression).map { e =>
+ BindReferences.bindReference(e, bufferAttrs).gen(ctx)
+ }
+ // evaluate result expressions
+ ctx.currentVars = aggResults
+ val resultVars = resultExpressions.map { e =>
+ BindReferences.bindReference(e, aggregateAttributes).gen(ctx)
+ }
+ (resultVars, s"""
+ | ${aggResults.map(_.code).mkString("\n")}
+ | ${resultVars.map(_.code).mkString("\n")}
+ """.stripMargin)
+ } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
+ // output the aggregate buffer directly
+ (bufVars, "")
+ } else {
+ // no aggregate function, the result should be literals
+ val resultVars = resultExpressions.map(_.gen(ctx))
+ (resultVars, resultVars.map(_.code).mkString("\n"))
+ }
+
+ val doAgg = ctx.freshName("doAgg")
+ ctx.addNewFunction(doAgg,
s"""
- | if (!$initAgg) {
- | $initAgg = true;
- |
+ | private void $doAgg() {
| // initialize aggregation buffer
| ${bufVars.map(_.code).mkString("\n")}
|
- | $childSource
- |
- | // output the result
- | ${consume(ctx, bufVars)}
+ | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
| }
- """.stripMargin
+ """.stripMargin)
- (rdd, source)
+ s"""
+ | if (!$initAgg) {
+ | $initAgg = true;
+ | $doAgg();
+ |
+ | // output the result
+ | $genResult
+ |
+ | ${consume(ctx, resultVars)}
+ | }
+ """.stripMargin
}
- override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
// only have DeclarativeAggregate
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
- // the mode could be only Partial or PartialMerge
- val updateExpr = if (modes.contains(Partial)) {
- functions.flatMap(_.updateExpressions)
- } else {
- functions.flatMap(_.mergeExpressions)
+ val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output
+ val updateExpr = aggregateExpressions.flatMap { e =>
+ e.mode match {
+ case Partial | Complete =>
+ e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
+ case PartialMerge | Final =>
+ e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
+ }
}
- val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output
- val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr))
ctx.currentVars = bufVars ++ input
// TODO: support subexpression elimination
- val codes = boundExpr.zipWithIndex.map { case (e, i) =>
- val ev = e.gen(ctx)
+ val aggVals = updateExpr.map(BindReferences.bindReference(_, inputAttrs).gen(ctx))
+ // aggregate buffer should be updated atomically
+ val updates = aggVals.zipWithIndex.map { case (ev, i) =>
s"""
- | ${ev.code}
| ${bufVars(i).isNull} = ${ev.isNull};
| ${bufVars(i).value} = ${ev.value};
""".stripMargin
}
s"""
- | // do aggregate and update aggregation buffer
- | ${codes.mkString("")}
+ | // do aggregate
+ | ${aggVals.map(_.code).mkString("\n")}
+ | // update aggregation buffer
+ | ${updates.mkString("")}
""".stripMargin
}
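The "updated atomically" comment above is about evaluation order rather than concurrency. The same idea written out in plain Scala, for a hypothetical (count, sum) buffer (nothing below is generated code):

  object BufferUpdateSketch {
    var count: Long = 0L
    var sum: Double = 0.0

    def update(inputValue: Double): Unit = {
      // 1. evaluate every update expression against the OLD buffer values first ...
      val newCount = count + 1
      val newSum   = sum + inputValue
      // 2. ... then write them all back, so no update expression can observe
      //    another expression's partially applied result
      count = newCount
      sum   = newSum
    }
  }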
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 6deb72adad..e7a73d5fbb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -37,11 +37,15 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
- protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+ override def upstream(): RDD[InternalRow] = {
+ child.asInstanceOf[CodegenSupport].upstream()
+ }
+
+ protected override def doProduce(ctx: CodegenContext): String = {
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}
- override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
val exprs = projectList.map(x =>
ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output)))
ctx.currentVars = input
@@ -76,11 +80,15 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
"numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
- protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+ override def upstream(): RDD[InternalRow] = {
+ child.asInstanceOf[CodegenSupport].upstream()
+ }
+
+ protected override def doProduce(ctx: CodegenContext): String = {
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}
- override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
val expr = ExpressionCanonicalizer.execute(
BindReferences.bindReference(condition, child.output))
ctx.currentVars = input
@@ -153,17 +161,21 @@ case class Range(
output: Seq[Attribute])
extends LeafNode with CodegenSupport {
- protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
- val initTerm = ctx.freshName("range_initRange")
+ override def upstream(): RDD[InternalRow] = {
+ sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i))
+ }
+
+ protected override def doProduce(ctx: CodegenContext): String = {
+ val initTerm = ctx.freshName("initRange")
ctx.addMutableState("boolean", initTerm, s"$initTerm = false;")
- val partitionEnd = ctx.freshName("range_partitionEnd")
+ val partitionEnd = ctx.freshName("partitionEnd")
ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;")
- val number = ctx.freshName("range_number")
+ val number = ctx.freshName("number")
ctx.addMutableState("long", number, s"$number = 0L;")
- val overflow = ctx.freshName("range_overflow")
+ val overflow = ctx.freshName("overflow")
ctx.addMutableState("boolean", overflow, s"$overflow = false;")
- val value = ctx.freshName("range_value")
+ val value = ctx.freshName("value")
val ev = ExprCode("", "false", value)
val BigInt = classOf[java.math.BigInteger].getName
val checkEnd = if (step > 0) {
@@ -172,38 +184,42 @@ case class Range(
s"$number > $partitionEnd"
}
- val rdd = sqlContext.sparkContext.parallelize(0 until numSlices, numSlices)
- .map(i => InternalRow(i))
+ ctx.addNewFunction("initRange",
+ s"""
+ | private void initRange(int idx) {
+ | $BigInt index = $BigInt.valueOf(idx);
+ | $BigInt numSlice = $BigInt.valueOf(${numSlices}L);
+ | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L);
+ | $BigInt step = $BigInt.valueOf(${step}L);
+ | $BigInt start = $BigInt.valueOf(${start}L);
+ |
+ | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
+ | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
+ | $number = Long.MAX_VALUE;
+ | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
+ | $number = Long.MIN_VALUE;
+ | } else {
+ | $number = st.longValue();
+ | }
+ |
+ | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
+ | .multiply(step).add(start);
+ | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
+ | $partitionEnd = Long.MAX_VALUE;
+ | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
+ | $partitionEnd = Long.MIN_VALUE;
+ | } else {
+ | $partitionEnd = end.longValue();
+ | }
+ | }
+ """.stripMargin)
- val code = s"""
+ s"""
| // initialize Range
| if (!$initTerm) {
| $initTerm = true;
| if (input.hasNext()) {
- | $BigInt index = $BigInt.valueOf(((InternalRow) input.next()).getInt(0));
- | $BigInt numSlice = $BigInt.valueOf(${numSlices}L);
- | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L);
- | $BigInt step = $BigInt.valueOf(${step}L);
- | $BigInt start = $BigInt.valueOf(${start}L);
- |
- | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
- | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
- | $number = Long.MAX_VALUE;
- | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
- | $number = Long.MIN_VALUE;
- | } else {
- | $number = st.longValue();
- | }
- |
- | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
- | .multiply(step).add(start);
- | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
- | $partitionEnd = Long.MAX_VALUE;
- | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
- | $partitionEnd = Long.MIN_VALUE;
- | } else {
- | $partitionEnd = end.longValue();
- | }
+ | initRange(((InternalRow) input.next()).getInt(0));
| } else {
| return;
| }
@@ -218,12 +234,6 @@ case class Range(
| ${consume(ctx, Seq(ev))}
| }
""".stripMargin
-
- (rdd, code)
- }
-
- def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
- throw new UnsupportedOperationException
}
protected override def doExecute(): RDD[InternalRow] = {
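For intuition about the bounds computed by the extracted initRange() method, a worked example with assumed parameters (start = 0, step = 1, numElements = 10, numSlices = 3; the divisions truncate exactly as the BigInteger arithmetic above does):

  // partition idx covers [ idx * numElements / numSlices * step + start,
  //                        (idx + 1) * numElements / numSlices * step + start )
  // idx = 0  ->  [0, 3)
  // idx = 1  ->  [3, 6)
  // idx = 2  ->  [6, 10)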
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index b1004bc5bc..08fb7c9d84 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -153,6 +153,13 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
)
}
+ test("agg without groups and functions") {
+ checkAnswer(
+ testData2.agg(lit(1)),
+ Row(1)
+ )
+ }
+
test("average") {
checkAnswer(
testData2.agg(avg('a), mean('a)),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 989cb29429..51a50c1fa3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1939,58 +1939,61 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("Common subexpression elimination") {
- // select from a table to prevent constant folding.
- val df = sql("SELECT a, b from testData2 limit 1")
- checkAnswer(df, Row(1, 1))
-
- checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2))
- checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3))
-
- // This does not work because the expressions get grouped like (a + a) + 1
- checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3))
- checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3))
-
- // Identity udf that tracks the number of times it is called.
- val countAcc = sparkContext.accumulator(0, "CallCount")
- sqlContext.udf.register("testUdf", (x: Int) => {
- countAcc.++=(1)
- x
- })
+ // TODO: support subexpression elimination in whole stage codegen
+ withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
+ // select from a table to prevent constant folding.
+ val df = sql("SELECT a, b from testData2 limit 1")
+ checkAnswer(df, Row(1, 1))
+
+ checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2))
+ checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3))
+
+ // This does not work because the expressions get grouped like (a + a) + 1
+ checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3))
+ checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3))
+
+ // Identity udf that tracks the number of times it is called.
+ val countAcc = sparkContext.accumulator(0, "CallCount")
+ sqlContext.udf.register("testUdf", (x: Int) => {
+ countAcc.++=(1)
+ x
+ })
+
+ // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value
+ // is correct.
+ def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = {
+ countAcc.setValue(0)
+ checkAnswer(df, expectedResult)
+ assert(countAcc.value == expectedCount)
+ }
- // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value
- // is correct.
- def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = {
- countAcc.setValue(0)
- checkAnswer(df, expectedResult)
- assert(countAcc.value == expectedCount)
+ verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1)
+ verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
+ verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1)
+ verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2)
+ verifyCallCount(
+ df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1)
+
+ verifyCallCount(
+ df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2)
+
+ val testUdf = functions.udf((x: Int) => {
+ countAcc.++=(1)
+ x
+ })
+ verifyCallCount(
+ df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1)
+
+ // Would be nice if semantic equals for `+` understood commutative
+ verifyCallCount(
+ df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2)
+
+ // Try disabling it via configuration.
+ sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false")
+ verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2)
+ sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true")
+ verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
}
-
- verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1)
- verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
- verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1)
- verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2)
- verifyCallCount(
- df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1)
-
- verifyCallCount(
- df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2)
-
- val testUdf = functions.udf((x: Int) => {
- countAcc.++=(1)
- x
- })
- verifyCallCount(
- df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1)
-
- // Would be nice if semantic equals for `+` understood commutative
- verifyCallCount(
- df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2)
-
- // Try disabling it via configuration.
- sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false")
- verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2)
- sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true")
- verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
}
test("SPARK-10707: nullability should be correctly propagated through set operations (1)") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index cbae19ebd2..82f6811503 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -335,22 +335,24 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
test("save metrics") {
withTempPath { file =>
- val previousExecutionIds = sqlContext.listener.executionIdToData.keySet
- // Assume the execution plan is
- // PhysicalRDD(nodeId = 0)
- person.select('name).write.format("json").save(file.getAbsolutePath)
- sparkContext.listenerBus.waitUntilEmpty(10000)
- val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
- assert(executionIds.size === 1)
- val executionId = executionIds.head
- val jobs = sqlContext.listener.getExecution(executionId).get.jobs
- // Use "<=" because there is a race condition that we may miss some jobs
- // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event.
- assert(jobs.size <= 1)
- val metricValues = sqlContext.listener.getExecutionMetrics(executionId)
- // Because "save" will create a new DataFrame internally, we cannot get the real metric id.
- // However, we still can check the value.
- assert(metricValues.values.toSeq === Seq("2"))
+ withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
+ val previousExecutionIds = sqlContext.listener.executionIdToData.keySet
+ // Assume the execution plan is
+ // PhysicalRDD(nodeId = 0)
+ person.select('name).write.format("json").save(file.getAbsolutePath)
+ sparkContext.listenerBus.waitUntilEmpty(10000)
+ val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
+ assert(executionIds.size === 1)
+ val executionId = executionIds.head
+ val jobs = sqlContext.listener.getExecution(executionId).get.jobs
+ // Use "<=" because there is a race condition that we may miss some jobs
+ // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event.
+ assert(jobs.size <= 1)
+ val metricValues = sqlContext.listener.getExecutionMetrics(executionId)
+ // Because "save" will create a new DataFrame internally, we cannot get the real metric id.
+ // However, we still can check the value.
+ assert(metricValues.values.toSeq === Seq("2"))
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index d48143762c..7d6bff8295 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -199,7 +199,7 @@ private[sql] trait SQLTestUtils
val schema = df.schema
val childRDD = df
.queryExecution
- .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter]
+ .sparkPlan.asInstanceOf[org.apache.spark.sql.execution.Filter]
.child
.execute()
.map(row => Row.fromSeq(row.copy().toSeq(schema)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index 9a24a2487a..a3e5243b68 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -97,10 +97,12 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
}
sqlContext.listenerManager.register(listener)
- val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count()
- df.collect()
- df.collect()
- Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect()
+ withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
+ val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count()
+ df.collect()
+ df.collect()
+ Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect()
+ }
assert(metrics.length == 3)
assert(metrics(0) == 1)