aboutsummaryrefslogtreecommitdiff
path: root/sql/core/src
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <viirya@gmail.com>2016-03-10 10:04:56 -0800
committerDavies Liu <davies.liu@gmail.com>2016-03-10 10:04:56 -0800
commitd24801ad285ac3f2282fe20d1250a010673e2f96 (patch)
tree50f17954052f4d70b79ea6eec890d04efa35f2ef /sql/core/src
parent74267beb3546d316c659499a9ff577437541f072 (diff)
downloadspark-d24801ad285ac3f2282fe20d1250a010673e2f96.tar.gz
spark-d24801ad285ac3f2282fe20d1250a010673e2f96.tar.bz2
spark-d24801ad285ac3f2282fe20d1250a010673e2f96.zip
[SPARK-13636] [SQL] Directly consume UnsafeRow in wholestage codegen plans
JIRA: https://issues.apache.org/jira/browse/SPARK-13636 ## What changes were proposed in this pull request? As shown in the wholestage codegen verion of Sort operator, when Sort is top of Exchange (or other operator that produce UnsafeRow), we will create variables from UnsafeRow, than create another UnsafeRow using these variables. We should avoid the unnecessary unpack and pack variables from UnsafeRows. ## How was this patch tested? All existing wholestage codegen tests should be passed. Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #11484 from viirya/direct-consume-unsaferow.
Diffstat (limited to 'sql/core/src')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala28
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala24
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala2
8 files changed, 43 insertions, 23 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
index 524285bc87..a84e180ad1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
@@ -93,7 +93,7 @@ case class Expand(
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
/*
* When the projections list looks like:
* expr1A, exprB, expr1C
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
index 2ea889ea72..5a67cd0c24 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
@@ -105,6 +105,8 @@ case class Sort(
// Name of sorter variable used in codegen.
private var sorterVariable: String = _
+ override def preferUnsafeRow: Boolean = true
+
override protected def doProduce(ctx: CodegenContext): String = {
val needToSort = ctx.freshName("needToSort")
ctx.addMutableState("boolean", needToSort, s"$needToSort = true;")
@@ -153,18 +155,22 @@ case class Sort(
""".stripMargin.trim
}
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
- val colExprs = child.output.zipWithIndex.map { case (attr, i) =>
- BoundReference(i, attr.dataType, attr.nullable)
- }
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
+ if (row != null) {
+ s"$sorterVariable.insertRow((UnsafeRow)$row);"
+ } else {
+ val colExprs = child.output.zipWithIndex.map { case (attr, i) =>
+ BoundReference(i, attr.dataType, attr.nullable)
+ }
- ctx.currentVars = input
- val code = GenerateUnsafeProjection.createCode(ctx, colExprs)
+ ctx.currentVars = input
+ val code = GenerateUnsafeProjection.createCode(ctx, colExprs)
- s"""
- | // Convert the input attributes to an UnsafeRow and add it to the sorter
- | ${code.code}
- | $sorterVariable.insertRow(${code.value});
- """.stripMargin.trim
+ s"""
+ | // Convert the input attributes to an UnsafeRow and add it to the sorter
+ | ${code.code}
+ | $sorterVariable.insertRow(${code.value});
+ """.stripMargin.trim
+ }
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index dd831e60cb..e8e42d72d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -65,7 +65,12 @@ trait CodegenSupport extends SparkPlan {
/**
* Which SparkPlan is calling produce() of this one. It's itself for the first SparkPlan.
*/
- private var parent: CodegenSupport = null
+ protected var parent: CodegenSupport = null
+
+ /**
+ * Whether this SparkPlan prefers to accept UnsafeRow as input in doConsume.
+ */
+ def preferUnsafeRow: Boolean = false
/**
* Returns all the RDDs of InternalRow which generates the input rows.
@@ -176,11 +181,20 @@ trait CodegenSupport extends SparkPlan {
} else {
input
}
+
+ val evaluated =
+ if (row != null && preferUnsafeRow) {
+ // Current plan can consume UnsafeRows directly.
+ ""
+ } else {
+ evaluateRequiredVariables(child.output, inputVars, usedInputs)
+ }
+
s"""
|
|/*** CONSUME: ${toCommentSafeString(this.simpleString)} */
- |${evaluateRequiredVariables(child.output, inputVars, usedInputs)}
- |${doConsume(ctx, inputVars)}
+ |${evaluated}
+ |${doConsume(ctx, inputVars, row)}
""".stripMargin
}
@@ -195,7 +209,7 @@ trait CodegenSupport extends SparkPlan {
* if (isNull1 || !value2) continue;
* # call consume(), which will call parent.doConsume()
*/
- protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
throw new UnsupportedOperationException
}
}
@@ -238,7 +252,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport
s"""
| while (!shouldStop() && $input.hasNext()) {
| InternalRow $row = (InternalRow) $input.next();
- | ${consume(ctx, columns).trim}
+ | ${consume(ctx, columns, row).trim}
| }
""".stripMargin
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index f856634cf7..1c4d594cd8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -139,7 +139,7 @@ case class TungstenAggregate(
}
}
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
if (groupingExpressions.isEmpty) {
doConsumeWithoutKeys(ctx, input)
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 4901298227..6ebbc8be6f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -49,7 +49,7 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
references.filter(a => usedMoreThanOnce.contains(a.exprId))
}
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
val exprs = projectList.map(x =>
ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output)))
ctx.currentVars = input
@@ -88,7 +88,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
val numOutput = metricTerm(ctx, "numOutputRows")
val expr = ExpressionCanonicalizer.execute(
BindReferences.bindReference(condition, child.output))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index fed88b8c0a..034bf15262 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -136,7 +136,7 @@ package object debug {
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
consume(ctx, input)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index c52662a61e..4c8f8080a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -107,7 +107,7 @@ case class BroadcastHashJoin(
streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)
}
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
if (joinType == Inner) {
codegenInner(ctx, input)
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 5a7516b7f9..ca624a5a84 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -65,7 +65,7 @@ trait BaseLimit extends UnaryNode with CodegenSupport {
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
val stopEarly = ctx.freshName("stopEarly")
ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;")