From b9dfdcc63bb12bc24de96060e756889c2ceda519 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 28 Jan 2016 17:01:12 -0800 Subject: Revert "[SPARK-13031] [SQL] cleanup codegen and improve test coverage" This reverts commit cc18a7199240bf3b03410c1ba6704fe7ce6ae38e. --- .../expressions/codegen/CodeGenerator.scala | 13 +- .../codegen/GenerateMutableProjection.scala | 2 +- .../spark/sql/execution/WholeStageCodegen.scala | 188 +++++++-------------- .../execution/aggregate/TungstenAggregate.scala | 88 +++------- .../spark/sql/execution/basicOperators.scala | 96 +++++------ .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 103 ++++++----- .../sql/execution/metric/SQLMetricsSuite.scala | 34 ++-- .../org/apache/spark/sql/test/SQLTestUtils.scala | 2 +- .../spark/sql/util/DataFrameCallbackSuite.scala | 10 +- 9 files changed, 202 insertions(+), 334 deletions(-) (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index e6704cf8bb..2747c315ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -144,23 +144,14 @@ class CodegenContext { private val curId = new java.util.concurrent.atomic.AtomicInteger() - /** - * A prefix used to generate fresh name. - */ - var freshNamePrefix = "" - /** * Returns a term name that is unique within this instance of a `CodeGenerator`. * * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` * function.) */ - def freshName(name: String): String = { - if (freshNamePrefix == "") { - s"$name${curId.getAndIncrement}" - } else { - s"${freshNamePrefix}_$name${curId.getAndIncrement}" - } + def freshName(prefix: String): String = { + s"$prefix${curId.getAndIncrement}" } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index ec31db19b9..d9fe76133c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -93,7 +93,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu // Can't call setNullAt on DecimalType, because we need to keep the offset s""" if (this.isNull_$i) { - ${ctx.setColumn("mutableRow", e.dataType, i, "null")}; + ${ctx.setColumn("mutableRow", e.dataType, i, null)}; } else { ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index ef81ba60f0..57f4945de9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -22,11 +22,9 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ +import 
org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, LeafExpression} import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.util.Utils /** * An interface for those physical operators that support codegen. @@ -44,16 +42,10 @@ trait CodegenSupport extends SparkPlan { private var parent: CodegenSupport = null /** - * Returns the RDD of InternalRow which generates the input rows. + * Returns an input RDD of InternalRow and Java source code to process them. */ - def upstream(): RDD[InternalRow] - - /** - * Returns Java source code to process the rows from upstream. - */ - def produce(ctx: CodegenContext, parent: CodegenSupport): String = { + def produce(ctx: CodegenContext, parent: CodegenSupport): (RDD[InternalRow], String) = { this.parent = parent - ctx.freshNamePrefix = nodeName doProduce(ctx) } @@ -74,41 +66,16 @@ trait CodegenSupport extends SparkPlan { * # call consume(), wich will call parent.doConsume() * } */ - protected def doProduce(ctx: CodegenContext): String + protected def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) /** - * Consume the columns generated from current SparkPlan, call it's parent. + * Consume the columns generated from current SparkPlan, call it's parent or create an iterator. */ - def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { - if (input != null) { - assert(input.length == output.length) - } - parent.consumeChild(ctx, this, input, row) + protected def consume(ctx: CodegenContext, columns: Seq[ExprCode]): String = { + assert(columns.length == output.length) + parent.doConsume(ctx, this, columns) } - /** - * Consume the columns generated from it's child, call doConsume() or emit the rows. - */ - def consumeChild( - ctx: CodegenContext, - child: SparkPlan, - input: Seq[ExprCode], - row: String = null): String = { - ctx.freshNamePrefix = nodeName - if (row != null) { - ctx.currentVars = null - ctx.INPUT_ROW = row - val evals = child.output.zipWithIndex.map { case (attr, i) => - BoundReference(i, attr.dataType, attr.nullable).gen(ctx) - } - s""" - | ${evals.map(_.code).mkString("\n")} - | ${doConsume(ctx, evals)} - """.stripMargin - } else { - doConsume(ctx, input) - } - } /** * Generate the Java source code to process the rows from child SparkPlan. 
@@ -122,9 +89,7 @@ trait CodegenSupport extends SparkPlan { * # call consume(), which will call parent.doConsume() * } */ - protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { - throw new UnsupportedOperationException - } + def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String } @@ -137,36 +102,31 @@ trait CodegenSupport extends SparkPlan { case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { override def output: Seq[Attribute] = child.output - override def outputPartitioning: Partitioning = child.outputPartitioning - override def outputOrdering: Seq[SortOrder] = child.outputOrdering - - override def doPrepare(): Unit = { - child.prepare() - } - override def doExecute(): RDD[InternalRow] = { - child.execute() - } + override def supportCodegen: Boolean = true - override def supportCodegen: Boolean = false - - override def upstream(): RDD[InternalRow] = { - child.execute() - } - - override def doProduce(ctx: CodegenContext): String = { + override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true)) val row = ctx.freshName("row") ctx.INPUT_ROW = row ctx.currentVars = null val columns = exprs.map(_.gen(ctx)) - s""" - | while (input.hasNext()) { + val code = s""" + | while (input.hasNext()) { | InternalRow $row = (InternalRow) input.next(); | ${columns.map(_.code).mkString("\n")} | ${consume(ctx, columns)} | } """.stripMargin + (child.execute(), code) + } + + def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + throw new UnsupportedOperationException + } + + override def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException } override def simpleString: String = "INPUT" @@ -183,20 +143,16 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { * * -> execute() * | - * doExecute() ---------> upstream() -------> upstream() ------> execute() - * | - * -----------------> produce() + * doExecute() --------> produce() * | * doProduce() -------> produce() * | - * doProduce() + * doProduce() ---> execute() * | * consume() - * consumeChild() <-----------| + * doConsume() ------------| * | - * doConsume() - * | - * consumeChild() <----- consume() + * doConsume() <----- consume() * * SparkPlan A should override doProduce() and doConsume(). 
* @@ -206,48 +162,37 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) extends SparkPlan with CodegenSupport { - override def supportCodegen: Boolean = false - override def output: Seq[Attribute] = plan.output - override def outputPartitioning: Partitioning = plan.outputPartitioning - override def outputOrdering: Seq[SortOrder] = plan.outputOrdering - - override def doPrepare(): Unit = { - plan.prepare() - } override def doExecute(): RDD[InternalRow] = { val ctx = new CodegenContext - val code = plan.produce(ctx, this) + val (rdd, code) = plan.produce(ctx, this) val references = ctx.references.toArray val source = s""" public Object generate(Object[] references) { - return new GeneratedIterator(references); + return new GeneratedIterator(references); } class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { - private Object[] references; - ${ctx.declareMutableStates()} - ${ctx.declareAddedFunctions()} + private Object[] references; + ${ctx.declareMutableStates()} - public GeneratedIterator(Object[] references) { + public GeneratedIterator(Object[] references) { this.references = references; ${ctx.initMutableStates()} - } + } - protected void processNext() throws java.io.IOException { + protected void processNext() { $code - } + } } - """ - + """ // try to compile, helpful for debug // println(s"${CodeFormatter.format(source)}") CodeGenerator.compile(source) - plan.upstream().mapPartitions { iter => - + rdd.mapPartitions { iter => val clazz = CodeGenerator.compile(source) val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] buffer.setInput(iter) @@ -258,47 +203,29 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) } } - override def upstream(): RDD[InternalRow] = { + override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { throw new UnsupportedOperationException } - override def doProduce(ctx: CodegenContext): String = { - throw new UnsupportedOperationException - } - - override def consumeChild( - ctx: CodegenContext, - child: SparkPlan, - input: Seq[ExprCode], - row: String = null): String = { - - if (row != null) { - // There is an UnsafeRow already + override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + if (input.nonEmpty) { + val colExprs = output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable) + } + // generate the code to create a UnsafeRow + ctx.currentVars = input + val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) s""" - | currentRow = $row; + | ${code.code.trim} + | currentRow = ${code.value}; | return; - """.stripMargin + """.stripMargin } else { - assert(input != null) - if (input.nonEmpty) { - val colExprs = output.zipWithIndex.map { case (attr, i) => - BoundReference(i, attr.dataType, attr.nullable) - } - // generate the code to create a UnsafeRow - ctx.currentVars = input - val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false) - s""" - | ${code.code.trim} - | currentRow = ${code.value}; - | return; - """.stripMargin - } else { - // There is no columns - s""" - | currentRow = unsafeRow; - | return; - """.stripMargin - } + // There is no columns + s""" + | currentRow = unsafeRow; + | return; + """.stripMargin } } @@ -319,7 +246,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) builder.append(simpleString) builder.append("\n") 
- plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder) + plan.generateTreeString(depth + 1, lastChildren :+children.isEmpty :+ true, builder) if (children.nonEmpty) { children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder)) children.last.generateTreeString(depth + 1, lastChildren :+ true, builder) @@ -359,14 +286,13 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru case plan: CodegenSupport if supportCodegen(plan) && // Whole stage codegen is only useful when there are at least two levels of operators that // support it (save at least one projection/iterator). - (Utils.isTesting || plan.children.exists(supportCodegen)) => + plan.children.exists(supportCodegen) => var inputs = ArrayBuffer[SparkPlan]() val combined = plan.transform { case p if !supportCodegen(p) => - val input = apply(p) // collapse them recursively - inputs += input - InputAdapter(input) + inputs += p + InputAdapter(p) }.asInstanceOf[CodegenSupport] WholeStageCodegen(combined, inputs) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index cbd2634b89..23e54f344d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -117,7 +117,9 @@ case class TungstenAggregate( override def supportCodegen: Boolean = { groupingExpressions.isEmpty && // ImperativeAggregate is not supported right now - !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) + !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) && + // final aggregation only have one row, do not need to codegen + !aggregateExpressions.exists(e => e.mode == Final || e.mode == Complete) } // The variables used as aggregation buffer @@ -125,11 +127,7 @@ case class TungstenAggregate( private val modes = aggregateExpressions.map(_.mode).distinct - override def upstream(): RDD[InternalRow] = { - child.asInstanceOf[CodegenSupport].upstream() - } - - protected override def doProduce(ctx: CodegenContext): String = { + protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { val initAgg = ctx.freshName("initAgg") ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") @@ -139,80 +137,50 @@ case class TungstenAggregate( bufVars = initExpr.map { e => val isNull = ctx.freshName("bufIsNull") val value = ctx.freshName("bufValue") - ctx.addMutableState("boolean", isNull, "") - ctx.addMutableState(ctx.javaType(e.dataType), value, "") // The initial expression should not access any column val ev = e.gen(ctx) val initVars = s""" - | $isNull = ${ev.isNull}; - | $value = ${ev.value}; + | boolean $isNull = ${ev.isNull}; + | ${ctx.javaType(e.dataType)} $value = ${ev.value}; """.stripMargin ExprCode(ev.code + initVars, isNull, value) } - // generate variables for output - val (resultVars, genResult) = if (modes.contains(Final) | modes.contains(Complete)) { - // evaluate aggregate results - ctx.currentVars = bufVars - val bufferAttrs = functions.flatMap(_.aggBufferAttributes) - val aggResults = functions.map(_.evaluateExpression).map { e => - BindReferences.bindReference(e, bufferAttrs).gen(ctx) - } - // evaluate result expressions - ctx.currentVars = aggResults - val resultVars = resultExpressions.map { e => - 
BindReferences.bindReference(e, aggregateAttributes).gen(ctx) - } - (resultVars, s""" - | ${aggResults.map(_.code).mkString("\n")} - | ${resultVars.map(_.code).mkString("\n")} - """.stripMargin) - } else { - // output the aggregate buffer directly - (bufVars, "") - } - - val doAgg = ctx.freshName("doAgg") - ctx.addNewFunction(doAgg, + val (rdd, childSource) = child.asInstanceOf[CodegenSupport].produce(ctx, this) + val source = s""" - | private void $doAgg() { + | if (!$initAgg) { + | $initAgg = true; + | | // initialize aggregation buffer | ${bufVars.map(_.code).mkString("\n")} | - | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} + | $childSource + | + | // output the result + | ${consume(ctx, bufVars)} | } - """.stripMargin) + """.stripMargin - s""" - | if (!$initAgg) { - | $initAgg = true; - | $doAgg(); - | - | // output the result - | $genResult - | - | ${consume(ctx, resultVars)} - | } - """.stripMargin + (rdd, source) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) - val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output - val updateExpr = aggregateExpressions.flatMap { e => - e.mode match { - case Partial | Complete => - e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions - case PartialMerge | Final => - e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions - } + // the mode could be only Partial or PartialMerge + val updateExpr = if (modes.contains(Partial)) { + functions.flatMap(_.updateExpressions) + } else { + functions.flatMap(_.mergeExpressions) } + val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output + val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr)) ctx.currentVars = bufVars ++ input // TODO: support subexpression elimination - val updates = updateExpr.zipWithIndex.map { case (e, i) => - val ev = BindReferences.bindReference[Expression](e, inputAttrs).gen(ctx) + val codes = boundExpr.zipWithIndex.map { case (e, i) => + val ev = e.gen(ctx) s""" | ${ev.code} | ${bufVars(i).isNull} = ${ev.isNull}; @@ -222,7 +190,7 @@ case class TungstenAggregate( s""" | // do aggregate and update aggregation buffer - | ${updates.mkString("")} + | ${codes.mkString("")} """.stripMargin } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index e7a73d5fbb..6deb72adad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -37,15 +37,11 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) override def output: Seq[Attribute] = projectList.map(_.toAttribute) - override def upstream(): RDD[InternalRow] = { - child.asInstanceOf[CodegenSupport].upstream() - } - - protected override def doProduce(ctx: CodegenContext): String = { + protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { val exprs = projectList.map(x => 
ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output))) ctx.currentVars = input @@ -80,15 +76,11 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - override def upstream(): RDD[InternalRow] = { - child.asInstanceOf[CodegenSupport].upstream() - } - - protected override def doProduce(ctx: CodegenContext): String = { + protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { val expr = ExpressionCanonicalizer.execute( BindReferences.bindReference(condition, child.output)) ctx.currentVars = input @@ -161,21 +153,17 @@ case class Range( output: Seq[Attribute]) extends LeafNode with CodegenSupport { - override def upstream(): RDD[InternalRow] = { - sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i)) - } - - protected override def doProduce(ctx: CodegenContext): String = { - val initTerm = ctx.freshName("initRange") + protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = { + val initTerm = ctx.freshName("range_initRange") ctx.addMutableState("boolean", initTerm, s"$initTerm = false;") - val partitionEnd = ctx.freshName("partitionEnd") + val partitionEnd = ctx.freshName("range_partitionEnd") ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;") - val number = ctx.freshName("number") + val number = ctx.freshName("range_number") ctx.addMutableState("long", number, s"$number = 0L;") - val overflow = ctx.freshName("overflow") + val overflow = ctx.freshName("range_overflow") ctx.addMutableState("boolean", overflow, s"$overflow = false;") - val value = ctx.freshName("value") + val value = ctx.freshName("range_value") val ev = ExprCode("", "false", value) val BigInt = classOf[java.math.BigInteger].getName val checkEnd = if (step > 0) { @@ -184,42 +172,38 @@ case class Range( s"$number > $partitionEnd" } - ctx.addNewFunction("initRange", - s""" - | private void initRange(int idx) { - | $BigInt index = $BigInt.valueOf(idx); - | $BigInt numSlice = $BigInt.valueOf(${numSlices}L); - | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L); - | $BigInt step = $BigInt.valueOf(${step}L); - | $BigInt start = $BigInt.valueOf(${start}L); - | - | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); - | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { - | $number = Long.MAX_VALUE; - | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { - | $number = Long.MIN_VALUE; - | } else { - | $number = st.longValue(); - | } - | - | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice) - | .multiply(step).add(start); - | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { - | $partitionEnd = Long.MAX_VALUE; - | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { - | $partitionEnd = Long.MIN_VALUE; - | } else { - | $partitionEnd = end.longValue(); - | } - | } - """.stripMargin) + val rdd = sqlContext.sparkContext.parallelize(0 until numSlices, numSlices) + .map(i => InternalRow(i)) - s""" + val code = s""" | // initialize Range | if (!$initTerm) { | $initTerm = true; | 
if (input.hasNext()) { - | initRange(((InternalRow) input.next()).getInt(0)); + | $BigInt index = $BigInt.valueOf(((InternalRow) input.next()).getInt(0)); + | $BigInt numSlice = $BigInt.valueOf(${numSlices}L); + | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L); + | $BigInt step = $BigInt.valueOf(${step}L); + | $BigInt start = $BigInt.valueOf(${start}L); + | + | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); + | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { + | $number = Long.MAX_VALUE; + | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { + | $number = Long.MIN_VALUE; + | } else { + | $number = st.longValue(); + | } + | + | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice) + | .multiply(step).add(start); + | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { + | $partitionEnd = Long.MAX_VALUE; + | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { + | $partitionEnd = Long.MIN_VALUE; + | } else { + | $partitionEnd = end.longValue(); + | } | } else { | return; | } @@ -234,6 +218,12 @@ case class Range( | ${consume(ctx, Seq(ev))} | } """.stripMargin + + (rdd, code) + } + + def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = { + throw new UnsupportedOperationException } protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 51a50c1fa3..989cb29429 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1939,61 +1939,58 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("Common subexpression elimination") { - // TODO: support subexpression elimination in whole stage codegen - withSQLConf("spark.sql.codegen.wholeStage" -> "false") { - // select from a table to prevent constant folding. - val df = sql("SELECT a, b from testData2 limit 1") - checkAnswer(df, Row(1, 1)) - - checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) - checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3)) - - // This does not work because the expressions get grouped like (a + a) + 1 - checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3)) - checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) - - // Identity udf that tracks the number of times it is called. - val countAcc = sparkContext.accumulator(0, "CallCount") - sqlContext.udf.register("testUdf", (x: Int) => { - countAcc.++=(1) - x - }) - - // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value - // is correct. - def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { - countAcc.setValue(0) - checkAnswer(df, expectedResult) - assert(countAcc.value == expectedCount) - } + // select from a table to prevent constant folding. + val df = sql("SELECT a, b from testData2 limit 1") + checkAnswer(df, Row(1, 1)) + + checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) + checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3)) + + // This does not work because the expressions get grouped like (a + a) + 1 + checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3)) + checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) + + // Identity udf that tracks the number of times it is called. 
+ val countAcc = sparkContext.accumulator(0, "CallCount") + sqlContext.udf.register("testUdf", (x: Int) => { + countAcc.++=(1) + x + }) - verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1) - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) - verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1) - verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2) - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1) - - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) - - val testUdf = functions.udf((x: Int) => { - countAcc.++=(1) - x - }) - verifyCallCount( - df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) - - // Would be nice if semantic equals for `+` understood commutative - verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) - - // Try disabling it via configuration. - sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) - sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") - verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value + // is correct. + def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { + countAcc.setValue(0) + checkAnswer(df, expectedResult) + assert(countAcc.value == expectedCount) } + + verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1) + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2) + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1) + + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) + + val testUdf = functions.udf((x: Int) => { + countAcc.++=(1) + x + }) + verifyCallCount( + df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) + + // Would be nice if semantic equals for `+` understood commutative + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) + + // Try disabling it via configuration. 
+ sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) + sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) } test("SPARK-10707: nullability should be correctly propagated through set operations (1)") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 82f6811503..cbae19ebd2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -335,24 +335,22 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { test("save metrics") { withTempPath { file => - withSQLConf("spark.sql.codegen.wholeStage" -> "false") { - val previousExecutionIds = sqlContext.listener.executionIdToData.keySet - // Assume the execution plan is - // PhysicalRDD(nodeId = 0) - person.select('name).write.format("json").save(file.getAbsolutePath) - sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) - assert(executionIds.size === 1) - val executionId = executionIds.head - val jobs = sqlContext.listener.getExecution(executionId).get.jobs - // Use "<=" because there is a race condition that we may miss some jobs - // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. - assert(jobs.size <= 1) - val metricValues = sqlContext.listener.getExecutionMetrics(executionId) - // Because "save" will create a new DataFrame internally, we cannot get the real metric id. - // However, we still can check the value. - assert(metricValues.values.toSeq === Seq("2")) - } + val previousExecutionIds = sqlContext.listener.executionIdToData.keySet + // Assume the execution plan is + // PhysicalRDD(nodeId = 0) + person.select('name).write.format("json").save(file.getAbsolutePath) + sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size === 1) + val executionId = executionIds.head + val jobs = sqlContext.listener.getExecution(executionId).get.jobs + // Use "<=" because there is a race condition that we may miss some jobs + // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. + assert(jobs.size <= 1) + val metricValues = sqlContext.listener.getExecutionMetrics(executionId) + // Because "save" will create a new DataFrame internally, we cannot get the real metric id. + // However, we still can check the value. 
+ assert(metricValues.values.toSeq === Seq("2")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 7d6bff8295..d48143762c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -199,7 +199,7 @@ private[sql] trait SQLTestUtils val schema = df.schema val childRDD = df .queryExecution - .sparkPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] + .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] .child .execute() .map(row => Row.fromSeq(row.copy().toSeq(schema))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index a3e5243b68..9a24a2487a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -97,12 +97,10 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { } sqlContext.listenerManager.register(listener) - withSQLConf("spark.sql.codegen.wholeStage" -> "false") { - val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count() - df.collect() - df.collect() - Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect() - } + val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count() + df.collect() + df.collect() + Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect() assert(metrics.length == 3) assert(metrics(0) == 1) -- cgit v1.2.3
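
For readers following the WholeStageCodegen changes above: after this revert, `doProduce` returns both the upstream `RDD[InternalRow]` and the generated Java source, and a child's `consume()` call routes straight to `parent.doConsume()`. The sketch below is a stripped-down, self-contained analogue of that contract, not Spark code: `CodegenSketch`, `InputSketch`, and `RootSketch` are made-up names, `Seq[Seq[Any]]` stands in for `RDD[InternalRow]`, and the `CodegenContext` parameter is dropped for brevity.

// Illustrative sketch only: a stripped-down analogue of the restored
// produce/consume protocol, with plain Scala collections instead of
// Spark's RDD / SparkPlan / CodegenContext.
object CodegenSketch {

  final case class ExprCode(code: String, isNull: String, value: String)

  trait CodegenSupport {
    private var parent: CodegenSupport = null

    // produce() hands back both the input data and the generated "Java"
    // source that processes it, mirroring the (RDD[InternalRow], String)
    // pair in the patch above.
    def produce(parent: CodegenSupport): (Seq[Seq[Any]], String) = {
      this.parent = parent
      doProduce()
    }

    protected def doProduce(): (Seq[Seq[Any]], String)

    // A child calls consume() to ask its parent for the code that handles
    // one row's worth of column variables.
    protected def consume(columns: Seq[ExprCode]): String =
      parent.doConsume(this, columns)

    def doConsume(child: CodegenSupport, input: Seq[ExprCode]): String
  }

  // Toy leaf operator: plays the role of InputAdapter, scanning the input
  // iterator and feeding column variables upward.
  final class InputSketch(data: Seq[Seq[Any]], width: Int) extends CodegenSupport {
    protected def doProduce(): (Seq[Seq[Any]], String) = {
      val cols = (0 until width).map(i =>
        ExprCode(s"Object c$i = row.get($i);", "false", s"c$i"))
      val loop =
        s"""while (input.hasNext()) {
           |  InternalRow row = (InternalRow) input.next();
           |  ${cols.map(_.code).mkString("\n  ")}
           |  ${consume(cols)}
           |}""".stripMargin
      (data, loop)
    }
    def doConsume(child: CodegenSupport, input: Seq[ExprCode]): String =
      throw new UnsupportedOperationException("leaf node has no child")
  }

  // Toy root operator: plays the role of WholeStageCodegen, emitting the
  // code that materializes the current row for the caller.
  final class RootSketch(child: CodegenSupport) extends CodegenSupport {
    def generate(): (Seq[Seq[Any]], String) = child.produce(this)
    protected def doProduce(): (Seq[Seq[Any]], String) =
      throw new UnsupportedOperationException
    def doConsume(child: CodegenSupport, input: Seq[ExprCode]): String =
      s"currentRow = toUnsafeRow(${input.map(_.value).mkString(", ")}); return;"
  }

  def main(args: Array[String]): Unit = {
    val (rows, source) = new RootSketch(new InputSketch(Seq(Seq(1, "a")), width = 2)).generate()
    println(s"input rows: ${rows.size}")
    println(source)
  }
}

The main difference from the version being reverted is visible in the sketch's produce signature: the input data travels out of produce() together with the generated source, instead of being fetched through a separate upstream() method.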