aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2016-02-02 22:13:10 -0800
committerDavies Liu <davies.liu@gmail.com>2016-02-02 22:13:10 -0800
commite86f8f63bfa3c15659b94e831b853b1bc9ddae32 (patch)
tree3e0de83a60c33c153d6e2876121dfe2e196d5108
parent335f10edad8c759bad3dbd0660ed4dd5d70ddd8b (diff)
downloadspark-e86f8f63bfa3c15659b94e831b853b1bc9ddae32.tar.gz
spark-e86f8f63bfa3c15659b94e831b853b1bc9ddae32.tar.bz2
spark-e86f8f63bfa3c15659b94e831b853b1bc9ddae32.zip
[SPARK-13147] [SQL] improve readability of generated code
1. try to avoid the suffix (unique id) 2. remove the comment if there is no code generated. 3. re-arrange the order of functions 4. drop the new line for inlined blocks. Author: Davies Liu <davies@databricks.com> Closes #11032 from davies/better_suffix.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala8
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala27
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala31
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala13
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala14
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala7
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala2
7 files changed, 63 insertions, 39 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 353fb92581..c73b2f8f2a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -103,8 +103,12 @@ abstract class Expression extends TreeNode[Expression] {
val value = ctx.freshName("value")
val ve = ExprCode("", isNull, value)
ve.code = genCode(ctx, ve)
- // Add `this` in the comment.
- ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim)
+ if (ve.code != "") {
+ // Add `this` in the comment.
+ ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim)
+ } else {
+ ve
+ }
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index a30aba1617..63e19564dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -156,7 +156,11 @@ class CodegenContext {
/** The variable name of the input row in generated code. */
final var INPUT_ROW = "i"
- private val curId = new java.util.concurrent.atomic.AtomicInteger()
+ /**
+ * The map from a variable name to its next ID.
+ */
+ private val freshNameIds = new mutable.HashMap[String, Int]
+ freshNameIds += INPUT_ROW -> 1
/**
* A prefix used to generate fresh name.
@@ -164,16 +168,21 @@ class CodegenContext {
var freshNamePrefix = ""
/**
- * Returns a term name that is unique within this instance of a `CodeGenerator`.
- *
- * (Since we aren't in a macro context we do not seem to have access to the built in `freshName`
- * function.)
+ * Returns a term name that is unique within this instance of a `CodegenContext`.
*/
- def freshName(name: String): String = {
- if (freshNamePrefix == "") {
- s"$name${curId.getAndIncrement}"
+ def freshName(name: String): String = synchronized {
+ val fullName = if (freshNamePrefix == "") {
+ name
+ } else {
+ s"${freshNamePrefix}_$name"
+ }
+ if (freshNameIds.contains(fullName)) {
+ val id = freshNameIds(fullName)
+ freshNameIds(fullName) = id + 1
+ s"$fullName$id"
} else {
- s"${freshNamePrefix}_$name${curId.getAndIncrement}"
+ freshNameIds += fullName -> 1
+ fullName
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 9f2f82d68c..6b24fae9f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -173,22 +173,26 @@ case class GetArrayStructFields(
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
val arrayClass = classOf[GenericArrayData].getName
nullSafeCodeGen(ctx, ev, eval => {
+ val n = ctx.freshName("n")
+ val values = ctx.freshName("values")
+ val j = ctx.freshName("j")
+ val row = ctx.freshName("row")
s"""
- final int n = $eval.numElements();
- final Object[] values = new Object[n];
- for (int j = 0; j < n; j++) {
- if ($eval.isNullAt(j)) {
- values[j] = null;
+ final int $n = $eval.numElements();
+ final Object[] $values = new Object[$n];
+ for (int $j = 0; $j < $n; $j++) {
+ if ($eval.isNullAt($j)) {
+ $values[$j] = null;
} else {
- final InternalRow row = $eval.getStruct(j, $numFields);
- if (row.isNullAt($ordinal)) {
- values[j] = null;
+ final InternalRow $row = $eval.getStruct($j, $numFields);
+ if ($row.isNullAt($ordinal)) {
+ $values[$j] = null;
} else {
- values[j] = ${ctx.getValue("row", field.dataType, ordinal.toString)};
+ $values[$j] = ${ctx.getValue(row, field.dataType, ordinal.toString)};
}
}
}
- ${ev.value} = new $arrayClass(values);
+ ${ev.value} = new $arrayClass($values);
"""
})
}
@@ -227,12 +231,13 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
+ val index = ctx.freshName("index")
s"""
- final int index = (int) $eval2;
- if (index >= $eval1.numElements() || index < 0 || $eval1.isNullAt(index)) {
+ final int $index = (int) $eval2;
+ if ($index >= $eval1.numElements() || $index < 0 || $eval1.isNullAt($index)) {
${ev.isNull} = true;
} else {
- ${ev.value} = ${ctx.getValue(eval1, dataType, "index")};
+ ${ev.value} = ${ctx.getValue(eval1, dataType, index)};
}
"""
})
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 02b0f423ed..1475496907 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -170,8 +170,8 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
s"""
| while (input.hasNext()) {
| InternalRow $row = (InternalRow) input.next();
- | ${columns.map(_.code).mkString("\n")}
- | ${consume(ctx, columns)}
+ | ${columns.map(_.code).mkString("\n").trim}
+ | ${consume(ctx, columns).trim}
| }
""".stripMargin
}
@@ -236,15 +236,16 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
private Object[] references;
${ctx.declareMutableStates()}
- ${ctx.declareAddedFunctions()}
public GeneratedIterator(Object[] references) {
- this.references = references;
- ${ctx.initMutableStates()}
+ this.references = references;
+ ${ctx.initMutableStates()}
}
+ ${ctx.declareAddedFunctions()}
+
protected void processNext() throws java.io.IOException {
- $code
+ ${code.trim}
}
}
"""
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index f61db8594d..d024477061 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -211,9 +211,9 @@ case class TungstenAggregate(
| $doAgg();
|
| // output the result
- | $genResult
+ | ${genResult.trim}
|
- | ${consume(ctx, resultVars)}
+ | ${consume(ctx, resultVars).trim}
| }
""".stripMargin
}
@@ -242,9 +242,9 @@ case class TungstenAggregate(
}
s"""
| // do aggregate
- | ${aggVals.map(_.code).mkString("\n")}
+ | ${aggVals.map(_.code).mkString("\n").trim}
| // update aggregation buffer
- | ${updates.mkString("")}
+ | ${updates.mkString("\n").trim}
""".stripMargin
}
@@ -523,7 +523,7 @@ case class TungstenAggregate(
// Finally, sort the spilled aggregate buffers by key, and merge them together for same key.
s"""
// generate grouping key
- ${keyCode.code}
+ ${keyCode.code.trim}
UnsafeRow $buffer = null;
if ($checkFallback) {
// try to get the buffer from hash map
@@ -547,9 +547,9 @@ case class TungstenAggregate(
$incCounter
// evaluate aggregate function
- ${evals.map(_.code).mkString("\n")}
+ ${evals.map(_.code).mkString("\n").trim}
// update aggregate buffer
- ${updates.mkString("\n")}
+ ${updates.mkString("\n").trim}
"""
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index fd81531c93..ae4422195c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -93,9 +93,14 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
BindReferences.bindReference(condition, child.output))
ctx.currentVars = input
val eval = expr.gen(ctx)
+ val nullCheck = if (expr.nullable) {
+ s"!${eval.isNull} &&"
+ } else {
+ s""
+ }
s"""
| ${eval.code}
- | if (!${eval.isNull} && ${eval.value}) {
+ | if ($nullCheck ${eval.value}) {
| ${consume(ctx, ctx.currentVars)}
| }
""".stripMargin
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 1ccf0e3d06..ec2b9ab2cb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -199,7 +199,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
// These benchmark are skipped in normal build
ignore("benchmark") {
// testWholeStage(200 << 20)
- // testStddev(20 << 20)
+ // testStatFunctions(20 << 20)
// testAggregateWithKey(20 << 20)
// testBytesToBytesMap(1024 * 1024 * 50)
}