author     Davies Liu <davies@databricks.com>    2016-01-28 13:51:55 -0800
committer  Davies Liu <davies.liu@gmail.com>     2016-01-28 13:51:55 -0800
commit     cc18a7199240bf3b03410c1ba6704fe7ce6ae38e (patch)
tree       c28ed508523f45751172903ecd63130f6a3868c5
parent     676803963fcc08aa988aa6f14be3751314e006ca (diff)
[SPARK-13031] [SQL] cleanup codegen and improve test coverage
1. Enable whole-stage codegen during tests even when only one operator supports it.
2. Split doProduce() into two APIs: upstream() and doProduce().
3. Generate a per-operator prefix for fresh names.
4. Pass the UnsafeRow to the parent directly (avoid calling getters and building an UnsafeRow again).
5. Fix bugs and tests.

Author: Davies Liu <davies@databricks.com>

Closes #10944 from davies/gen_refactor.
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala             13
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala   2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala                                 188
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala                        88
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala                                     96
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala                                               103
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala                             34
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala                                             2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala                                  10
9 files changed, 334 insertions(+), 202 deletions(-)
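
As a quick orientation before the diffs: a toy, self-contained Scala sketch of the produce()/consume() protocol this patch refactors. It is not Spark code; ToyScan, ToyFilter and ToySink are illustrative stand-ins, and the real interface also carries a CodegenContext plus the consumeChild()/doConsume() split shown below. Each operator emits the code for its own step and calls consume() so that its parent emits the code that handles one row.

// Toy sketch of the produce()/consume() code-generation protocol (not Spark code).
trait ToyCodegen {
  private var parent: ToyCodegen = null

  // Called by the parent; remembers it so consume() can reach back up the tree.
  def produce(parent: ToyCodegen): String = {
    this.parent = parent
    doProduce()
  }

  // Generates the driving loop, or delegates production to the child.
  def doProduce(): String

  // Hands one generated row variable to the parent.
  def consume(row: String): String = parent.doConsume(row)

  // Generates the per-row code; leaf sources never consume, so fail by default.
  def doConsume(row: String): String = throw new UnsupportedOperationException
}

// A source: loops over the input iterator and feeds each row to its parent.
class ToyScan extends ToyCodegen {
  override def doProduce(): String =
    s"""while (input.hasNext()) {
       |  InternalRow row = (InternalRow) input.next();
       |  ${consume("row")}
       |}""".stripMargin
}

// A filter: wraps the parent's per-row code in a predicate check.
class ToyFilter(child: ToyCodegen) extends ToyCodegen {
  override def doProduce(): String = child.produce(this)
  override def doConsume(row: String): String =
    s"""if (pred($row)) {
       |  ${consume(row)}
       |}""".stripMargin
}

// The stage boundary: kicks off production and emits the row to the caller.
class ToySink(child: ToyCodegen) extends ToyCodegen {
  def generate(): String = child.produce(this)
  override def doProduce(): String = throw new UnsupportedOperationException
  override def doConsume(row: String): String = s"currentRow = $row; return;"
}

object ToyCodegenDemo {
  def main(args: Array[String]): Unit = {
    // Prints one nested while/if block, analogous to what WholeStageCodegen
    // produces for a Scan -> Filter -> output pipeline.
    println(new ToySink(new ToyFilter(new ToyScan)).generate())
  }
}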
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 2747c315ad..e6704cf8bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -145,13 +145,22 @@ class CodegenContext {
private val curId = new java.util.concurrent.atomic.AtomicInteger()
/**
+ * A prefix used to generate fresh name.
+ */
+ var freshNamePrefix = ""
+
+ /**
* Returns a term name that is unique within this instance of a `CodeGenerator`.
*
* (Since we aren't in a macro context we do not seem to have access to the built in `freshName`
* function.)
*/
- def freshName(prefix: String): String = {
- s"$prefix${curId.getAndIncrement}"
+ def freshName(name: String): String = {
+ if (freshNamePrefix == "") {
+ s"$name${curId.getAndIncrement}"
+ } else {
+ s"${freshNamePrefix}_$name${curId.getAndIncrement}"
+ }
}
/**
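The effect of the new freshNamePrefix can be checked with a standalone sketch of just the naming logic above (not the real CodegenContext; the operator names set on the prefix are illustrative):

import java.util.concurrent.atomic.AtomicInteger

// Standalone sketch of the freshName()/freshNamePrefix scheme added in this hunk.
object FreshNameSketch {
  private val curId = new AtomicInteger()
  var freshNamePrefix = ""

  def freshName(name: String): String =
    if (freshNamePrefix == "") s"$name${curId.getAndIncrement}"
    else s"${freshNamePrefix}_$name${curId.getAndIncrement}"

  def main(args: Array[String]): Unit = {
    println(freshName("value"))   // value0
    freshNamePrefix = "Filter"    // set per operator by produce()/consumeChild()
    println(freshName("value"))   // Filter_value1
    freshNamePrefix = "Project"
    println(freshName("value"))   // Project_value2
  }
}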
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index d9fe76133c..ec31db19b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -93,7 +93,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
// Can't call setNullAt on DecimalType, because we need to keep the offset
s"""
if (this.isNull_$i) {
- ${ctx.setColumn("mutableRow", e.dataType, i, null)};
+ ${ctx.setColumn("mutableRow", e.dataType, i, "null")};
} else {
${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 57f4945de9..ef81ba60f0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -22,9 +22,11 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, LeafExpression}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.util.Utils
/**
* An interface for those physical operators that support codegen.
@@ -42,10 +44,16 @@ trait CodegenSupport extends SparkPlan {
private var parent: CodegenSupport = null
/**
- * Returns an input RDD of InternalRow and Java source code to process them.
+ * Returns the RDD of InternalRow which generates the input rows.
*/
- def produce(ctx: CodegenContext, parent: CodegenSupport): (RDD[InternalRow], String) = {
+ def upstream(): RDD[InternalRow]
+
+ /**
+ * Returns Java source code to process the rows from upstream.
+ */
+ def produce(ctx: CodegenContext, parent: CodegenSupport): String = {
this.parent = parent
+ ctx.freshNamePrefix = nodeName
doProduce(ctx)
}
@@ -66,16 +74,41 @@ trait CodegenSupport extends SparkPlan {
* # call consume(), wich will call parent.doConsume()
* }
*/
- protected def doProduce(ctx: CodegenContext): (RDD[InternalRow], String)
+ protected def doProduce(ctx: CodegenContext): String
/**
- * Consume the columns generated from current SparkPlan, call it's parent or create an iterator.
+ * Consume the columns generated from current SparkPlan, call it's parent.
*/
- protected def consume(ctx: CodegenContext, columns: Seq[ExprCode]): String = {
- assert(columns.length == output.length)
- parent.doConsume(ctx, this, columns)
+ def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = {
+ if (input != null) {
+ assert(input.length == output.length)
+ }
+ parent.consumeChild(ctx, this, input, row)
}
+ /**
+ * Consume the columns generated from it's child, call doConsume() or emit the rows.
+ */
+ def consumeChild(
+ ctx: CodegenContext,
+ child: SparkPlan,
+ input: Seq[ExprCode],
+ row: String = null): String = {
+ ctx.freshNamePrefix = nodeName
+ if (row != null) {
+ ctx.currentVars = null
+ ctx.INPUT_ROW = row
+ val evals = child.output.zipWithIndex.map { case (attr, i) =>
+ BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
+ }
+ s"""
+ | ${evals.map(_.code).mkString("\n")}
+ | ${doConsume(ctx, evals)}
+ """.stripMargin
+ } else {
+ doConsume(ctx, input)
+ }
+ }
/**
* Generate the Java source code to process the rows from child SparkPlan.
@@ -89,7 +122,9 @@ trait CodegenSupport extends SparkPlan {
* # call consume(), which will call parent.doConsume()
* }
*/
- def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String
+ protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ throw new UnsupportedOperationException
+ }
}
@@ -102,31 +137,36 @@ trait CodegenSupport extends SparkPlan {
case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
override def output: Seq[Attribute] = child.output
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+ override def doPrepare(): Unit = {
+ child.prepare()
+ }
- override def supportCodegen: Boolean = true
+ override def doExecute(): RDD[InternalRow] = {
+ child.execute()
+ }
- override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+ override def supportCodegen: Boolean = false
+
+ override def upstream(): RDD[InternalRow] = {
+ child.execute()
+ }
+
+ override def doProduce(ctx: CodegenContext): String = {
val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
val row = ctx.freshName("row")
ctx.INPUT_ROW = row
ctx.currentVars = null
val columns = exprs.map(_.gen(ctx))
- val code = s"""
- | while (input.hasNext()) {
+ s"""
+ | while (input.hasNext()) {
| InternalRow $row = (InternalRow) input.next();
| ${columns.map(_.code).mkString("\n")}
| ${consume(ctx, columns)}
| }
""".stripMargin
- (child.execute(), code)
- }
-
- def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
- throw new UnsupportedOperationException
- }
-
- override def doExecute(): RDD[InternalRow] = {
- throw new UnsupportedOperationException
}
override def simpleString: String = "INPUT"
@@ -143,16 +183,20 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
*
* -> execute()
* |
- * doExecute() --------> produce()
+ * doExecute() ---------> upstream() -------> upstream() ------> execute()
+ * |
+ * -----------------> produce()
* |
* doProduce() -------> produce()
* |
- * doProduce() ---> execute()
+ * doProduce()
* |
* consume()
- * doConsume() ------------|
+ * consumeChild() <-----------|
* |
- * doConsume() <----- consume()
+ * doConsume()
+ * |
+ * consumeChild() <----- consume()
*
* SparkPlan A should override doProduce() and doConsume().
*
@@ -162,37 +206,48 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
extends SparkPlan with CodegenSupport {
+ override def supportCodegen: Boolean = false
+
override def output: Seq[Attribute] = plan.output
+ override def outputPartitioning: Partitioning = plan.outputPartitioning
+ override def outputOrdering: Seq[SortOrder] = plan.outputOrdering
+
+ override def doPrepare(): Unit = {
+ plan.prepare()
+ }
override def doExecute(): RDD[InternalRow] = {
val ctx = new CodegenContext
- val (rdd, code) = plan.produce(ctx, this)
+ val code = plan.produce(ctx, this)
val references = ctx.references.toArray
val source = s"""
public Object generate(Object[] references) {
- return new GeneratedIterator(references);
+ return new GeneratedIterator(references);
}
class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
- private Object[] references;
- ${ctx.declareMutableStates()}
+ private Object[] references;
+ ${ctx.declareMutableStates()}
+ ${ctx.declareAddedFunctions()}
- public GeneratedIterator(Object[] references) {
+ public GeneratedIterator(Object[] references) {
this.references = references;
${ctx.initMutableStates()}
- }
+ }
- protected void processNext() {
+ protected void processNext() throws java.io.IOException {
$code
- }
+ }
}
- """
+ """
+
// try to compile, helpful for debug
// println(s"${CodeFormatter.format(source)}")
CodeGenerator.compile(source)
- rdd.mapPartitions { iter =>
+ plan.upstream().mapPartitions { iter =>
+
val clazz = CodeGenerator.compile(source)
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
buffer.setInput(iter)
@@ -203,29 +258,47 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
}
}
- override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+ override def upstream(): RDD[InternalRow] = {
throw new UnsupportedOperationException
}
- override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
- if (input.nonEmpty) {
- val colExprs = output.zipWithIndex.map { case (attr, i) =>
- BoundReference(i, attr.dataType, attr.nullable)
- }
- // generate the code to create a UnsafeRow
- ctx.currentVars = input
- val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
- s"""
- | ${code.code.trim}
- | currentRow = ${code.value};
- | return;
- """.stripMargin
- } else {
- // There is no columns
+ override def doProduce(ctx: CodegenContext): String = {
+ throw new UnsupportedOperationException
+ }
+
+ override def consumeChild(
+ ctx: CodegenContext,
+ child: SparkPlan,
+ input: Seq[ExprCode],
+ row: String = null): String = {
+
+ if (row != null) {
+ // There is an UnsafeRow already
s"""
- | currentRow = unsafeRow;
+ | currentRow = $row;
| return;
""".stripMargin
+ } else {
+ assert(input != null)
+ if (input.nonEmpty) {
+ val colExprs = output.zipWithIndex.map { case (attr, i) =>
+ BoundReference(i, attr.dataType, attr.nullable)
+ }
+ // generate the code to create a UnsafeRow
+ ctx.currentVars = input
+ val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
+ s"""
+ | ${code.code.trim}
+ | currentRow = ${code.value};
+ | return;
+ """.stripMargin
+ } else {
+ // There is no columns
+ s"""
+ | currentRow = unsafeRow;
+ | return;
+ """.stripMargin
+ }
}
}
@@ -246,7 +319,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
builder.append(simpleString)
builder.append("\n")
- plan.generateTreeString(depth + 1, lastChildren :+children.isEmpty :+ true, builder)
+ plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder)
if (children.nonEmpty) {
children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
children.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
@@ -286,13 +359,14 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
case plan: CodegenSupport if supportCodegen(plan) &&
// Whole stage codegen is only useful when there are at least two levels of operators that
// support it (save at least one projection/iterator).
- plan.children.exists(supportCodegen) =>
+ (Utils.isTesting || plan.children.exists(supportCodegen)) =>
var inputs = ArrayBuffer[SparkPlan]()
val combined = plan.transform {
case p if !supportCodegen(p) =>
- inputs += p
- InputAdapter(p)
+ val input = apply(p) // collapse them recursively
+ inputs += input
+ InputAdapter(input)
}.asInstanceOf[CodegenSupport]
WholeStageCodegen(combined, inputs)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 23e54f344d..cbd2634b89 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -117,9 +117,7 @@ case class TungstenAggregate(
override def supportCodegen: Boolean = {
groupingExpressions.isEmpty &&
// ImperativeAggregate is not supported right now
- !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) &&
- // final aggregation only have one row, do not need to codegen
- !aggregateExpressions.exists(e => e.mode == Final || e.mode == Complete)
+ !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
}
// The variables used as aggregation buffer
@@ -127,7 +125,11 @@ case class TungstenAggregate(
private val modes = aggregateExpressions.map(_.mode).distinct
- protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+ override def upstream(): RDD[InternalRow] = {
+ child.asInstanceOf[CodegenSupport].upstream()
+ }
+
+ protected override def doProduce(ctx: CodegenContext): String = {
val initAgg = ctx.freshName("initAgg")
ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
@@ -137,50 +139,80 @@ case class TungstenAggregate(
bufVars = initExpr.map { e =>
val isNull = ctx.freshName("bufIsNull")
val value = ctx.freshName("bufValue")
+ ctx.addMutableState("boolean", isNull, "")
+ ctx.addMutableState(ctx.javaType(e.dataType), value, "")
// The initial expression should not access any column
val ev = e.gen(ctx)
val initVars = s"""
- | boolean $isNull = ${ev.isNull};
- | ${ctx.javaType(e.dataType)} $value = ${ev.value};
+ | $isNull = ${ev.isNull};
+ | $value = ${ev.value};
""".stripMargin
ExprCode(ev.code + initVars, isNull, value)
}
- val (rdd, childSource) = child.asInstanceOf[CodegenSupport].produce(ctx, this)
- val source =
+ // generate variables for output
+ val (resultVars, genResult) = if (modes.contains(Final) | modes.contains(Complete)) {
+ // evaluate aggregate results
+ ctx.currentVars = bufVars
+ val bufferAttrs = functions.flatMap(_.aggBufferAttributes)
+ val aggResults = functions.map(_.evaluateExpression).map { e =>
+ BindReferences.bindReference(e, bufferAttrs).gen(ctx)
+ }
+ // evaluate result expressions
+ ctx.currentVars = aggResults
+ val resultVars = resultExpressions.map { e =>
+ BindReferences.bindReference(e, aggregateAttributes).gen(ctx)
+ }
+ (resultVars, s"""
+ | ${aggResults.map(_.code).mkString("\n")}
+ | ${resultVars.map(_.code).mkString("\n")}
+ """.stripMargin)
+ } else {
+ // output the aggregate buffer directly
+ (bufVars, "")
+ }
+
+ val doAgg = ctx.freshName("doAgg")
+ ctx.addNewFunction(doAgg,
s"""
- | if (!$initAgg) {
- | $initAgg = true;
- |
+ | private void $doAgg() {
| // initialize aggregation buffer
| ${bufVars.map(_.code).mkString("\n")}
|
- | $childSource
- |
- | // output the result
- | ${consume(ctx, bufVars)}
+ | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
| }
- """.stripMargin
+ """.stripMargin)
- (rdd, source)
+ s"""
+ | if (!$initAgg) {
+ | $initAgg = true;
+ | $doAgg();
+ |
+ | // output the result
+ | $genResult
+ |
+ | ${consume(ctx, resultVars)}
+ | }
+ """.stripMargin
}
- override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
// only have DeclarativeAggregate
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
- // the mode could be only Partial or PartialMerge
- val updateExpr = if (modes.contains(Partial)) {
- functions.flatMap(_.updateExpressions)
- } else {
- functions.flatMap(_.mergeExpressions)
+ val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output
+ val updateExpr = aggregateExpressions.flatMap { e =>
+ e.mode match {
+ case Partial | Complete =>
+ e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
+ case PartialMerge | Final =>
+ e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
+ }
}
- val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output
- val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr))
ctx.currentVars = bufVars ++ input
// TODO: support subexpression elimination
- val codes = boundExpr.zipWithIndex.map { case (e, i) =>
- val ev = e.gen(ctx)
+ val updates = updateExpr.zipWithIndex.map { case (e, i) =>
+ val ev = BindReferences.bindReference[Expression](e, inputAttrs).gen(ctx)
s"""
| ${ev.code}
| ${bufVars(i).isNull} = ${ev.isNull};
@@ -190,7 +222,7 @@ case class TungstenAggregate(
s"""
| // do aggregate and update aggregation buffer
- | ${codes.mkString("")}
+ | ${updates.mkString("")}
""".stripMargin
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 6deb72adad..e7a73d5fbb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -37,11 +37,15 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
- protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+ override def upstream(): RDD[InternalRow] = {
+ child.asInstanceOf[CodegenSupport].upstream()
+ }
+
+ protected override def doProduce(ctx: CodegenContext): String = {
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}
- override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
val exprs = projectList.map(x =>
ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output)))
ctx.currentVars = input
@@ -76,11 +80,15 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
"numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
- protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+ override def upstream(): RDD[InternalRow] = {
+ child.asInstanceOf[CodegenSupport].upstream()
+ }
+
+ protected override def doProduce(ctx: CodegenContext): String = {
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}
- override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
val expr = ExpressionCanonicalizer.execute(
BindReferences.bindReference(condition, child.output))
ctx.currentVars = input
@@ -153,17 +161,21 @@ case class Range(
output: Seq[Attribute])
extends LeafNode with CodegenSupport {
- protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
- val initTerm = ctx.freshName("range_initRange")
+ override def upstream(): RDD[InternalRow] = {
+ sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i))
+ }
+
+ protected override def doProduce(ctx: CodegenContext): String = {
+ val initTerm = ctx.freshName("initRange")
ctx.addMutableState("boolean", initTerm, s"$initTerm = false;")
- val partitionEnd = ctx.freshName("range_partitionEnd")
+ val partitionEnd = ctx.freshName("partitionEnd")
ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;")
- val number = ctx.freshName("range_number")
+ val number = ctx.freshName("number")
ctx.addMutableState("long", number, s"$number = 0L;")
- val overflow = ctx.freshName("range_overflow")
+ val overflow = ctx.freshName("overflow")
ctx.addMutableState("boolean", overflow, s"$overflow = false;")
- val value = ctx.freshName("range_value")
+ val value = ctx.freshName("value")
val ev = ExprCode("", "false", value)
val BigInt = classOf[java.math.BigInteger].getName
val checkEnd = if (step > 0) {
@@ -172,38 +184,42 @@ case class Range(
s"$number > $partitionEnd"
}
- val rdd = sqlContext.sparkContext.parallelize(0 until numSlices, numSlices)
- .map(i => InternalRow(i))
+ ctx.addNewFunction("initRange",
+ s"""
+ | private void initRange(int idx) {
+ | $BigInt index = $BigInt.valueOf(idx);
+ | $BigInt numSlice = $BigInt.valueOf(${numSlices}L);
+ | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L);
+ | $BigInt step = $BigInt.valueOf(${step}L);
+ | $BigInt start = $BigInt.valueOf(${start}L);
+ |
+ | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
+ | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
+ | $number = Long.MAX_VALUE;
+ | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
+ | $number = Long.MIN_VALUE;
+ | } else {
+ | $number = st.longValue();
+ | }
+ |
+ | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
+ | .multiply(step).add(start);
+ | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
+ | $partitionEnd = Long.MAX_VALUE;
+ | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
+ | $partitionEnd = Long.MIN_VALUE;
+ | } else {
+ | $partitionEnd = end.longValue();
+ | }
+ | }
+ """.stripMargin)
- val code = s"""
+ s"""
| // initialize Range
| if (!$initTerm) {
| $initTerm = true;
| if (input.hasNext()) {
- | $BigInt index = $BigInt.valueOf(((InternalRow) input.next()).getInt(0));
- | $BigInt numSlice = $BigInt.valueOf(${numSlices}L);
- | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L);
- | $BigInt step = $BigInt.valueOf(${step}L);
- | $BigInt start = $BigInt.valueOf(${start}L);
- |
- | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
- | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
- | $number = Long.MAX_VALUE;
- | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
- | $number = Long.MIN_VALUE;
- | } else {
- | $number = st.longValue();
- | }
- |
- | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
- | .multiply(step).add(start);
- | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
- | $partitionEnd = Long.MAX_VALUE;
- | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
- | $partitionEnd = Long.MIN_VALUE;
- | } else {
- | $partitionEnd = end.longValue();
- | }
+ | initRange(((InternalRow) input.next()).getInt(0));
| } else {
| return;
| }
@@ -218,12 +234,6 @@ case class Range(
| ${consume(ctx, Seq(ev))}
| }
""".stripMargin
-
- (rdd, code)
- }
-
- def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
- throw new UnsupportedOperationException
}
protected override def doExecute(): RDD[InternalRow] = {
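The per-slice bounds that the generated initRange(idx) helper computes can be reproduced with a standalone sketch (illustrative inputs, not the generated Java): each slice's start and end are computed with BigInteger and clamped to Long.MaxValue / Long.MinValue to avoid overflow.

import java.math.BigInteger

// Standalone sketch of the start/end computation moved into initRange(idx).
object RangeBoundsSketch {
  private def clamp(v: BigInteger): Long =
    if (v.compareTo(BigInteger.valueOf(Long.MaxValue)) > 0) Long.MaxValue
    else if (v.compareTo(BigInteger.valueOf(Long.MinValue)) < 0) Long.MinValue
    else v.longValue()

  def bounds(idx: Int, numSlices: Long, numElements: Long,
             start: Long, step: Long): (Long, Long) = {
    val index = BigInteger.valueOf(idx.toLong)
    val slices = BigInteger.valueOf(numSlices)
    val elements = BigInteger.valueOf(numElements)
    val st = index.multiply(elements).divide(slices)
      .multiply(BigInteger.valueOf(step)).add(BigInteger.valueOf(start))
    val end = index.add(BigInteger.ONE).multiply(elements).divide(slices)
      .multiply(BigInteger.valueOf(step)).add(BigInteger.valueOf(start))
    (clamp(st), clamp(end))
  }

  def main(args: Array[String]): Unit = {
    // Slice 1 of 4 for a range of 10 elements starting at 0 with step 1.
    println(bounds(1, 4, 10, 0, 1))   // (2,5): this slice covers 2, 3, 4
  }
}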
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 989cb29429..51a50c1fa3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1939,58 +1939,61 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("Common subexpression elimination") {
- // select from a table to prevent constant folding.
- val df = sql("SELECT a, b from testData2 limit 1")
- checkAnswer(df, Row(1, 1))
-
- checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2))
- checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3))
-
- // This does not work because the expressions get grouped like (a + a) + 1
- checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3))
- checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3))
-
- // Identity udf that tracks the number of times it is called.
- val countAcc = sparkContext.accumulator(0, "CallCount")
- sqlContext.udf.register("testUdf", (x: Int) => {
- countAcc.++=(1)
- x
- })
+ // TODO: support subexpression elimination in whole stage codegen
+ withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
+ // select from a table to prevent constant folding.
+ val df = sql("SELECT a, b from testData2 limit 1")
+ checkAnswer(df, Row(1, 1))
+
+ checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2))
+ checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3))
+
+ // This does not work because the expressions get grouped like (a + a) + 1
+ checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3))
+ checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3))
+
+ // Identity udf that tracks the number of times it is called.
+ val countAcc = sparkContext.accumulator(0, "CallCount")
+ sqlContext.udf.register("testUdf", (x: Int) => {
+ countAcc.++=(1)
+ x
+ })
+
+ // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value
+ // is correct.
+ def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = {
+ countAcc.setValue(0)
+ checkAnswer(df, expectedResult)
+ assert(countAcc.value == expectedCount)
+ }
- // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value
- // is correct.
- def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = {
- countAcc.setValue(0)
- checkAnswer(df, expectedResult)
- assert(countAcc.value == expectedCount)
+ verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1)
+ verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
+ verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1)
+ verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2)
+ verifyCallCount(
+ df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1)
+
+ verifyCallCount(
+ df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2)
+
+ val testUdf = functions.udf((x: Int) => {
+ countAcc.++=(1)
+ x
+ })
+ verifyCallCount(
+ df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1)
+
+ // Would be nice if semantic equals for `+` understood commutative
+ verifyCallCount(
+ df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2)
+
+ // Try disabling it via configuration.
+ sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false")
+ verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2)
+ sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true")
+ verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
}
-
- verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1)
- verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
- verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1)
- verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2)
- verifyCallCount(
- df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1)
-
- verifyCallCount(
- df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2)
-
- val testUdf = functions.udf((x: Int) => {
- countAcc.++=(1)
- x
- })
- verifyCallCount(
- df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1)
-
- // Would be nice if semantic equals for `+` understood commutative
- verifyCallCount(
- df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2)
-
- // Try disabling it via configuration.
- sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false")
- verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2)
- sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true")
- verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
}
test("SPARK-10707: nullability should be correctly propagated through set operations (1)") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index cbae19ebd2..82f6811503 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -335,22 +335,24 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
test("save metrics") {
withTempPath { file =>
- val previousExecutionIds = sqlContext.listener.executionIdToData.keySet
- // Assume the execution plan is
- // PhysicalRDD(nodeId = 0)
- person.select('name).write.format("json").save(file.getAbsolutePath)
- sparkContext.listenerBus.waitUntilEmpty(10000)
- val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
- assert(executionIds.size === 1)
- val executionId = executionIds.head
- val jobs = sqlContext.listener.getExecution(executionId).get.jobs
- // Use "<=" because there is a race condition that we may miss some jobs
- // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event.
- assert(jobs.size <= 1)
- val metricValues = sqlContext.listener.getExecutionMetrics(executionId)
- // Because "save" will create a new DataFrame internally, we cannot get the real metric id.
- // However, we still can check the value.
- assert(metricValues.values.toSeq === Seq("2"))
+ withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
+ val previousExecutionIds = sqlContext.listener.executionIdToData.keySet
+ // Assume the execution plan is
+ // PhysicalRDD(nodeId = 0)
+ person.select('name).write.format("json").save(file.getAbsolutePath)
+ sparkContext.listenerBus.waitUntilEmpty(10000)
+ val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
+ assert(executionIds.size === 1)
+ val executionId = executionIds.head
+ val jobs = sqlContext.listener.getExecution(executionId).get.jobs
+ // Use "<=" because there is a race condition that we may miss some jobs
+ // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event.
+ assert(jobs.size <= 1)
+ val metricValues = sqlContext.listener.getExecutionMetrics(executionId)
+ // Because "save" will create a new DataFrame internally, we cannot get the real metric id.
+ // However, we still can check the value.
+ assert(metricValues.values.toSeq === Seq("2"))
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index d48143762c..7d6bff8295 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -199,7 +199,7 @@ private[sql] trait SQLTestUtils
val schema = df.schema
val childRDD = df
.queryExecution
- .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter]
+ .sparkPlan.asInstanceOf[org.apache.spark.sql.execution.Filter]
.child
.execute()
.map(row => Row.fromSeq(row.copy().toSeq(schema)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index 9a24a2487a..a3e5243b68 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -97,10 +97,12 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
}
sqlContext.listenerManager.register(listener)
- val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count()
- df.collect()
- df.collect()
- Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect()
+ withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
+ val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count()
+ df.collect()
+ df.collect()
+ Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect()
+ }
assert(metrics.length == 3)
assert(metrics(0) == 1)