From e86f8f63bfa3c15659b94e831b853b1bc9ddae32 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 2 Feb 2016 22:13:10 -0800 Subject: [SPARK-13147] [SQL] improve readability of generated code 1. try to avoid the suffix (unique id) 2. remove the comment if there is no code generated. 3. re-arrange the order of functions 4. drop the new line for inlined blocks. Author: Davies Liu Closes #11032 from davies/better_suffix. --- .../sql/catalyst/expressions/Expression.scala | 8 ++++-- .../expressions/codegen/CodeGenerator.scala | 27 ++++++++++++------- .../expressions/complexTypeExtractors.scala | 31 +++++++++++++--------- .../spark/sql/execution/WholeStageCodegen.scala | 13 ++++----- .../execution/aggregate/TungstenAggregate.scala | 14 +++++----- .../spark/sql/execution/basicOperators.scala | 7 ++++- .../sql/execution/BenchmarkWholeStageCodegen.scala | 2 +- 7 files changed, 63 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 353fb92581..c73b2f8f2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -103,8 +103,12 @@ abstract class Expression extends TreeNode[Expression] { val value = ctx.freshName("value") val ve = ExprCode("", isNull, value) ve.code = genCode(ctx, ve) - // Add `this` in the comment. - ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim) + if (ve.code != "") { + // Add `this` in the comment. 
+ ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim) + } else { + ve + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index a30aba1617..63e19564dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -156,7 +156,11 @@ class CodegenContext { /** The variable name of the input row in generated code. */ final var INPUT_ROW = "i" - private val curId = new java.util.concurrent.atomic.AtomicInteger() + /** + * The map from a variable name to its next ID. + */ + private val freshNameIds = new mutable.HashMap[String, Int] + freshNameIds += INPUT_ROW -> 1 /** * A prefix used to generate fresh name. @@ -164,16 +168,21 @@ class CodegenContext { var freshNamePrefix = "" /** - * Returns a term name that is unique within this instance of a `CodeGenerator`. - * - * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` - * function.) + * Returns a term name that is unique within this instance of a `CodegenContext`. 
*/ - def freshName(name: String): String = { - if (freshNamePrefix == "") { - s"$name${curId.getAndIncrement}" + def freshName(name: String): String = synchronized { + val fullName = if (freshNamePrefix == "") { + name + } else { + s"${freshNamePrefix}_$name" + } + if (freshNameIds.contains(fullName)) { + val id = freshNameIds(fullName) + freshNameIds(fullName) = id + 1 + s"$fullName$id" } else { - s"${freshNamePrefix}_$name${curId.getAndIncrement}" + freshNameIds += fullName -> 1 + fullName } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 9f2f82d68c..6b24fae9f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -173,22 +173,26 @@ case class GetArrayStructFields( override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val arrayClass = classOf[GenericArrayData].getName nullSafeCodeGen(ctx, ev, eval => { + val n = ctx.freshName("n") + val values = ctx.freshName("values") + val j = ctx.freshName("j") + val row = ctx.freshName("row") s""" - final int n = $eval.numElements(); - final Object[] values = new Object[n]; - for (int j = 0; j < n; j++) { - if ($eval.isNullAt(j)) { - values[j] = null; + final int $n = $eval.numElements(); + final Object[] $values = new Object[$n]; + for (int $j = 0; $j < $n; $j++) { + if ($eval.isNullAt($j)) { + $values[$j] = null; } else { - final InternalRow row = $eval.getStruct(j, $numFields); - if (row.isNullAt($ordinal)) { - values[j] = null; + final InternalRow $row = $eval.getStruct($j, $numFields); + if ($row.isNullAt($ordinal)) { + $values[$j] = null; } else { - values[j] = ${ctx.getValue("row", field.dataType, ordinal.toString)}; + $values[$j] = ${ctx.getValue(row, field.dataType, 
ordinal.toString)}; } } } - ${ev.value} = new $arrayClass(values); + ${ev.value} = new $arrayClass($values); """ }) } @@ -227,12 +231,13 @@ case class GetArrayItem(child: Expression, ordinal: Expression) override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + val index = ctx.freshName("index") s""" - final int index = (int) $eval2; - if (index >= $eval1.numElements() || index < 0 || $eval1.isNullAt(index)) { + final int $index = (int) $eval2; + if ($index >= $eval1.numElements() || $index < 0 || $eval1.isNullAt($index)) { ${ev.isNull} = true; } else { - ${ev.value} = ${ctx.getValue(eval1, dataType, "index")}; + ${ev.value} = ${ctx.getValue(eval1, dataType, index)}; } """ }) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 02b0f423ed..1475496907 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -170,8 +170,8 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { s""" | while (input.hasNext()) { | InternalRow $row = (InternalRow) input.next(); - | ${columns.map(_.code).mkString("\n")} - | ${consume(ctx, columns)} + | ${columns.map(_.code).mkString("\n").trim} + | ${consume(ctx, columns).trim} | } """.stripMargin } @@ -236,15 +236,16 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) private Object[] references; ${ctx.declareMutableStates()} - ${ctx.declareAddedFunctions()} public GeneratedIterator(Object[] references) { - this.references = references; - ${ctx.initMutableStates()} + this.references = references; + ${ctx.initMutableStates()} } + ${ctx.declareAddedFunctions()} + protected void processNext() throws java.io.IOException { - $code + ${code.trim} } } """ diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index f61db8594d..d024477061 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -211,9 +211,9 @@ case class TungstenAggregate( | $doAgg(); | | // output the result - | $genResult + | ${genResult.trim} | - | ${consume(ctx, resultVars)} + | ${consume(ctx, resultVars).trim} | } """.stripMargin } @@ -242,9 +242,9 @@ case class TungstenAggregate( } s""" | // do aggregate - | ${aggVals.map(_.code).mkString("\n")} + | ${aggVals.map(_.code).mkString("\n").trim} | // update aggregation buffer - | ${updates.mkString("")} + | ${updates.mkString("\n").trim} """.stripMargin } @@ -523,7 +523,7 @@ case class TungstenAggregate( // Finally, sort the spilled aggregate buffers by key, and merge them together for same key. 
s""" // generate grouping key - ${keyCode.code} + ${keyCode.code.trim} UnsafeRow $buffer = null; if ($checkFallback) { // try to get the buffer from hash map @@ -547,9 +547,9 @@ case class TungstenAggregate( $incCounter // evaluate aggregate function - ${evals.map(_.code).mkString("\n")} + ${evals.map(_.code).mkString("\n").trim} // update aggregate buffer - ${updates.mkString("\n")} + ${updates.mkString("\n").trim} """ } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index fd81531c93..ae4422195c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -93,9 +93,14 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit BindReferences.bindReference(condition, child.output)) ctx.currentVars = input val eval = expr.gen(ctx) + val nullCheck = if (expr.nullable) { + s"!${eval.isNull} &&" + } else { + s"" + } s""" | ${eval.code} - | if (!${eval.isNull} && ${eval.value}) { + | if ($nullCheck ${eval.value}) { | ${consume(ctx, ctx.currentVars)} | } """.stripMargin diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 1ccf0e3d06..ec2b9ab2cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -199,7 +199,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { // These benchmark are skipped in normal build ignore("benchmark") { // testWholeStage(200 << 20) - // testStddev(20 << 20) + // testStatFunctions(20 << 20) // testAggregateWithKey(20 << 20) // testBytesToBytesMap(1024 * 1024 * 50) } -- cgit v1.2.3