author     Davies Liu <davies@databricks.com>       2016-03-26 11:03:05 -0700
committer  Reynold Xin <rxin@databricks.com>        2016-03-26 11:03:05 -0700
commit     bd94ea4c80f4fc18f4000346d7c6717539846efb (patch)
tree       2d031be41791cb1b7b6c814c7420fb50281402fb /sql
parent     a91784fb6e47e2f72551e2379731e0a36fda9d04 (diff)
[SPARK-14175][SQL] whole stage codegen interface refactor
## What changes were proposed in this pull request?

1. Merge consumeChild() into consume().
2. Always generate code for both the input variables and the UnsafeRow; a plan can use either of them.

## How was this patch tested?

Existing tests.

Author: Davies Liu <davies@databricks.com>

Closes #11975 from davies/gen_refactor.
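In rough terms, the refactored contract looks like the following simplified, self-contained Scala sketch (hypothetical stand-ins, not Spark's real classes; the symmetric step that regenerates column variables from an existing row is elided):

```scala
// ExprCode mirrors Spark's (code, isNull, value) triple for a generated value.
case class ExprCode(code: String, isNull: String, value: String)

trait CodegenSupport {
  protected var parent: CodegenSupport = null

  // Single entry point after the refactor: callers pass either generated
  // column variables (outputVars) or the name of an existing row variable,
  // and consume() normalizes both into a row-valued ExprCode for the parent.
  final def consume(outputVars: Seq[ExprCode], row: String = null): String = {
    val rowVar =
      if (row != null) ExprCode("", "false", row) // row exists; no code to run
      else ExprCode("/* pack outputVars into an UnsafeRow */", "false", "result_row")
    parent.doConsume(outputVars, rowVar)
  }

  // Each plan overrides this and picks whichever input is cheaper to consume.
  def doConsume(input: Seq[ExprCode], row: ExprCode): String =
    throw new UnsupportedOperationException
}
```

The per-operator consumeChild() hook disappears; the UnsafeRow branch that each operator used to reimplement now lives in one place.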
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala                 |   3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala                      |   2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala                        |  26
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala           | 153
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala |   2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala              |   4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala               |   2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala     |   2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala                       |   2
9 files changed, 72 insertions(+), 124 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 3e2c799762..815ff01c4c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -271,7 +271,8 @@ private[sql] case class DataSourceScan(
| }
| }""".stripMargin)
- val exprRows = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
+ val exprRows =
+ output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, x._1.nullable))
ctx.INPUT_ROW = row
ctx.currentVars = null
val columns2 = exprRows.map(_.gen(ctx))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
index 05627ba9c7..bd23b7e3ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
@@ -93,7 +93,7 @@ case class Expand(
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
/*
* When the projections list looks like:
* expr1A, exprB, expr1C
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
index b4dd77041e..efd8760cd2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
@@ -98,6 +98,8 @@ case class Sort(
}
}
+ override def usedInputs: AttributeSet = AttributeSet(Seq.empty)
+
override def upstreams(): Seq[RDD[InternalRow]] = {
child.asInstanceOf[CodegenSupport].upstreams()
}
@@ -105,8 +107,6 @@ case class Sort(
// Name of sorter variable used in codegen.
private var sorterVariable: String = _
- override def preferUnsafeRow: Boolean = true
-
override protected def doProduce(ctx: CodegenContext): String = {
val needToSort = ctx.freshName("needToSort")
ctx.addMutableState("boolean", needToSort, s"$needToSort = true;")
@@ -158,22 +158,10 @@ case class Sort(
""".stripMargin.trim
}
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
- if (row != null) {
- s"$sorterVariable.insertRow((UnsafeRow)$row);"
- } else {
- val colExprs = child.output.zipWithIndex.map { case (attr, i) =>
- BoundReference(i, attr.dataType, attr.nullable)
- }
-
- ctx.currentVars = input
- val code = GenerateUnsafeProjection.createCode(ctx, colExprs)
-
- s"""
- | // Convert the input attributes to an UnsafeRow and add it to the sorter
- | ${code.code}
- | $sorterVariable.insertRow(${code.value});
- """.stripMargin.trim
- }
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+ s"""
+ |${row.code}
+ |$sorterVariable.insertRow((UnsafeRow)${row.value});
+ """.stripMargin
}
}
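For illustration, a minimal sketch (with a simplified ExprCode, and hypothetical variable names) of what the slimmed-down Sort.doConsume now emits; the branching it used to do lives in consume(), which hands over a row ExprCode whose code field is empty when an UnsafeRow already exists:

```scala
case class ExprCode(code: String, isNull: String, value: String)

// Mirrors the new Sort.doConsume template above: splice in whatever code (if
// any) is needed to materialize the row, then insert it into the sorter.
def sortConsumeTemplate(sorterVariable: String, row: ExprCode): String =
  s"""
     |${row.code}
     |$sorterVariable.insertRow((UnsafeRow)${row.value});
   """.stripMargin

// When the child already produced an UnsafeRow, row.code is empty and this
// collapses to a single call, e.g.:
//   sortConsumeTemplate("sorter_1", ExprCode("", "false", "inputadapter_row"))
// emits:  sorter_1.insertRow((UnsafeRow)inputadapter_row);
```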
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 0be0b8032a..1b13c8fd22 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -69,11 +69,6 @@ trait CodegenSupport extends SparkPlan {
protected var parent: CodegenSupport = null
/**
- * Whether this SparkPlan prefers to accept UnsafeRow as input in doConsume.
- */
- def preferUnsafeRow: Boolean = false
-
- /**
* Returns all the RDDs of InternalRow which generate the input rows.
*
* Note: right now we support up to two RDDs.
@@ -114,13 +109,52 @@ trait CodegenSupport extends SparkPlan {
protected def doProduce(ctx: CodegenContext): String
/**
- * Consume the columns generated from current SparkPlan, call it's parent.
+ * Consume the generated columns or row from current SparkPlan, call its parent's doConsume().
*/
- final def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = {
- if (input != null) {
- assert(input.length == output.length)
+ final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null): String = {
+ val inputVars =
+ if (row != null) {
+ ctx.currentVars = null
+ ctx.INPUT_ROW = row
+ output.zipWithIndex.map { case (attr, i) =>
+ BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
+ }
+ } else {
+ assert(outputVars != null)
+ assert(outputVars.length == output.length)
+ // outputVars will be used to generate the code for UnsafeRow, so we should copy them
+ outputVars.map(_.copy())
+ }
+ val rowVar = if (row != null) {
+ ExprCode("", "false", row)
+ } else {
+ if (outputVars.nonEmpty) {
+ val colExprs = output.zipWithIndex.map { case (attr, i) =>
+ BoundReference(i, attr.dataType, attr.nullable)
+ }
+ val evaluateInputs = evaluateVariables(outputVars)
+ // generate the code to create a UnsafeRow
+ ctx.currentVars = outputVars
+ val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
+ val code = s"""
+ |$evaluateInputs
+ |${ev.code.trim}
+ """.stripMargin.trim
+ ExprCode(code, "false", ev.value)
+ } else {
+ // There are no columns
+ ExprCode("", "false", "unsafeRow")
+ }
}
- parent.consumeChild(ctx, this, input, row)
+
+ ctx.freshNamePrefix = parent.variablePrefix
+ val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs)
+ s"""
+ |
+ |/*** CONSUME: ${toCommentSafeString(parent.simpleString)} */
+ |${evaluated}
+ |${parent.doConsume(ctx, inputVars, rowVar)}
+ """.stripMargin
}
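The two shapes the new consume() normalizes can be pictured with a simplified ExprCode; the names below (inputadapter_row, project_result, and the placeholder inside the string) are illustrative only, not taken from the generated source:

```scala
case class ExprCode(code: String, isNull: String, value: String)

// Case 1: the child handed over an existing UnsafeRow; nothing has to run
// before the parent can use it, so the code field is empty.
val rowAlreadyExists = ExprCode("", "false", "inputadapter_row")

// Case 2: only column variables exist; the code field carries the generated
// statements (GenerateUnsafeProjection output in the real code) that evaluate
// the variables and pack them into a row. Those statements execute only if
// the parent actually splices row.code into its own template.
val rowBuiltFromVars = ExprCode(
  code = "/* evaluate outputVars, then write them through an UnsafeRowWriter */",
  isNull = "false",
  value = "project_result")
```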
/**
@@ -160,47 +194,6 @@ trait CodegenSupport extends SparkPlan {
def usedInputs: AttributeSet = references
/**
- * Consume the columns generated from its child, call doConsume() or emit the rows.
- *
- * An operator could generate variables for the output, or a row, either one could be null.
- *
- * If the row is not null, we create variables to access the columns that are actually used by
- * current plan before calling doConsume().
- */
- def consumeChild(
- ctx: CodegenContext,
- child: SparkPlan,
- input: Seq[ExprCode],
- row: String = null): String = {
- ctx.freshNamePrefix = variablePrefix
- val inputVars =
- if (row != null) {
- ctx.currentVars = null
- ctx.INPUT_ROW = row
- child.output.zipWithIndex.map { case (attr, i) =>
- BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
- }
- } else {
- input
- }
-
- val evaluated =
- if (row != null && preferUnsafeRow) {
- // Current plan can consume UnsafeRows directly.
- ""
- } else {
- evaluateRequiredVariables(child.output, inputVars, usedInputs)
- }
-
- s"""
- |
- |/*** CONSUME: ${toCommentSafeString(this.simpleString)} */
- |${evaluated}
- |${doConsume(ctx, inputVars, row)}
- """.stripMargin
- }
-
- /**
* Generate the Java source code to process the rows from child SparkPlan.
*
* This should be overridden by subclasses to support codegen.
@@ -210,8 +203,10 @@ trait CodegenSupport extends SparkPlan {
* # code to evaluate the predicate expression, result is isNull1 and value2
* if (isNull1 || !value2) continue;
* # call consume(), which will call parent.doConsume()
+ *
+ * Note: A plan can consume the rows either as an UnsafeRow (row) or as a list of variables (input).
*/
- protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
+ def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
throw new UnsupportedOperationException
}
}
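As a concrete, hypothetical example of the new override point: a filter-like operator that consumes the variables and ignores the row, sketched against a simplified ExprCode with a callback standing in for consume():

```scala
case class ExprCode(code: String, isNull: String, value: String)

// A variable-consuming operator: it never references `row`, so the UnsafeRow
// construction code that consume() prepared is simply never emitted for it.
class MiniFilter(consumeNext: Seq[ExprCode] => String) {
  // Assume input.head evaluates the (already bound) predicate to a boolean,
  // matching the isNull/value pattern in the doc comment above.
  def doConsume(input: Seq[ExprCode], row: ExprCode): String =
    s"""
       |${input.head.code}
       |if (${input.head.isNull} || !${input.head.value}) continue;
       |${consumeNext(input)}
     """.stripMargin
}
```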
@@ -245,16 +240,11 @@ case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport
val input = ctx.freshName("input")
// Right now, InputAdapter is only used when there is one upstream.
ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
-
- val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
val row = ctx.freshName("row")
- ctx.INPUT_ROW = row
- ctx.currentVars = null
- val columns = exprs.map(_.gen(ctx))
s"""
| while ($input.hasNext()) {
| InternalRow $row = (InternalRow) $input.next();
- | ${consume(ctx, columns, row).trim}
+ | ${consume(ctx, null, row).trim}
| if (shouldStop()) return;
| }
""".stripMargin
@@ -282,18 +272,15 @@ object WholeStageCodegen {
* |
* doExecute() ---------> upstreams() -------> upstreams() ------> execute()
* |
- * -----------------> produce()
+ * +-----------------> produce()
* |
* doProduce() -------> produce()
* |
* doProduce()
* |
- * consume()
- * consumeChild() <-----------|
+ * doConsume() <--------- consume()
* |
- * doConsume()
- * |
- * consumeChild() <----- consume()
+ * doConsume() <-------- consume()
*
* SparkPlan A should override doProduce() and doConsume().
*
@@ -392,44 +379,16 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
throw new UnsupportedOperationException
}
- override def consumeChild(
- ctx: CodegenContext,
- child: SparkPlan,
- input: Seq[ExprCode],
- row: String = null): String = {
-
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val doCopy = if (ctx.copyResult) {
".copy()"
} else {
""
}
- if (row != null) {
- // There is an UnsafeRow already
- s"""
- |append($row$doCopy);
- """.stripMargin.trim
- } else {
- assert(input != null)
- if (input.nonEmpty) {
- val colExprs = output.zipWithIndex.map { case (attr, i) =>
- BoundReference(i, attr.dataType, attr.nullable)
- }
- val evaluateInputs = evaluateVariables(input)
- // generate the code to create a UnsafeRow
- ctx.currentVars = input
- val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
- s"""
- |$evaluateInputs
- |${code.code.trim}
- |append(${code.value}$doCopy);
- """.stripMargin.trim
- } else {
- // There is no columns
- s"""
- |append(unsafeRow);
- """.stripMargin.trim
- }
- }
+ s"""
+ |${row.code}
+ |append(${row.value}$doCopy);
+ """.stripMargin.trim
}
override def innerChildren: Seq[SparkPlan] = {
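A compact way to read the collapsed sink above, as a hedged sketch with a simplified ExprCode; copyResult here is a plain boolean standing in for ctx.copyResult, which flags upstream operators that reuse their row buffers and therefore need a defensive copy():

```scala
case class ExprCode(code: String, isNull: String, value: String)

// The old three-way branch (existing row / build row from vars / no columns)
// disappears because consume() already folded all three cases into `row`.
def sinkTemplate(row: ExprCode, copyResult: Boolean): String = {
  val doCopy = if (copyResult) ".copy()" else ""
  s"""
     |${row.code}
     |append(${row.value}$doCopy);
   """.stripMargin.trim
}
```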
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 28945a507c..7c215d1b96 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -139,7 +139,7 @@ case class TungstenAggregate(
}
}
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
if (groupingExpressions.isEmpty) {
doConsumeWithoutKeys(ctx, input)
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index ee3f1d70e1..70e04d022f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -49,7 +49,7 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
references.filter(a => usedMoreThanOnce.contains(a.exprId))
}
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val exprs = projectList.map(x =>
ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output)))
ctx.currentVars = input
@@ -107,7 +107,7 @@ case class Filter(condition: Expression, child: SparkPlan)
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val numOutput = metricTerm(ctx, "numOutputRows")
// filter out the nulls
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index d5ce1243d9..5e573b3159 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -137,7 +137,7 @@ package object debug {
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
consume(ctx, input)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index aa2da283b1..f5b083c216 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -110,7 +110,7 @@ case class BroadcastHashJoin(
streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)
}
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
joinType match {
case Inner => codegenInner(ctx, input)
case LeftOuter | RightOuter => codegenOuter(ctx, input)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index ca624a5a84..9643b52f96 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -65,7 +65,7 @@ trait BaseLimit extends UnaryNode with CodegenSupport {
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val stopEarly = ctx.freshName("stopEarly")
ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;")
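The hunk above is cut off here, but the pattern it sets up can be sketched. A minimal, illustrative template (not the real BaseLimit codegen) for how a limit operator might use the stopEarly flag that the surrounding produce() loop checks; count, limit, and body are assumed names:

```scala
// Hypothetical stop-early template: emit rows until `limit` is reached, then
// set the flag so the generated while-loop can exit on its next iteration.
def limitTemplate(stopEarly: String, count: String, limit: Int, body: String): String =
  s"""
     |if ($count < $limit) {
     |  $count += 1;
     |  $body
     |} else {
     |  $stopEarly = true; // the surrounding loop tests this and stops
     |}
   """.stripMargin
```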