about summary refs log tree commit diff
path: root/sql/core/src
diff options
context:
space:
mode:
Diffstat (limited to 'sql/core/src')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala                      |  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala                        | 28
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala           | 24
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala |  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala              |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala               |  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala     |  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala                       |  2
8 files changed, 43 insertions(+), 23 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
index 524285bc87..a84e180ad1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
@@ -93,7 +93,7 @@ case class Expand(
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
/*
* When the projections list looks like:
* expr1A, exprB, expr1C
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
index 2ea889ea72..5a67cd0c24 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
@@ -105,6 +105,8 @@ case class Sort(
// Name of sorter variable used in codegen.
private var sorterVariable: String = _
+ override def preferUnsafeRow: Boolean = true
+
override protected def doProduce(ctx: CodegenContext): String = {
val needToSort = ctx.freshName("needToSort")
ctx.addMutableState("boolean", needToSort, s"$needToSort = true;")
@@ -153,18 +155,22 @@ case class Sort(
""".stripMargin.trim
}
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
- val colExprs = child.output.zipWithIndex.map { case (attr, i) =>
- BoundReference(i, attr.dataType, attr.nullable)
- }
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
+ if (row != null) {
+ s"$sorterVariable.insertRow((UnsafeRow)$row);"
+ } else {
+ val colExprs = child.output.zipWithIndex.map { case (attr, i) =>
+ BoundReference(i, attr.dataType, attr.nullable)
+ }
- ctx.currentVars = input
- val code = GenerateUnsafeProjection.createCode(ctx, colExprs)
+ ctx.currentVars = input
+ val code = GenerateUnsafeProjection.createCode(ctx, colExprs)
- s"""
- | // Convert the input attributes to an UnsafeRow and add it to the sorter
- | ${code.code}
- | $sorterVariable.insertRow(${code.value});
- """.stripMargin.trim
+ s"""
+ | // Convert the input attributes to an UnsafeRow and add it to the sorter
+ | ${code.code}
+ | $sorterVariable.insertRow(${code.value});
+ """.stripMargin.trim
+ }
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index dd831e60cb..e8e42d72d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -65,7 +65,12 @@ trait CodegenSupport extends SparkPlan {
/**
* Which SparkPlan is calling produce() of this one. It's itself for the first SparkPlan.
*/
- private var parent: CodegenSupport = null
+ protected var parent: CodegenSupport = null
+
+ /**
+ * Whether this SparkPlan prefers to accept UnsafeRow as input in doConsume.
+ */
+ def preferUnsafeRow: Boolean = false
/**
* Returns all the RDDs of InternalRow which generates the input rows.
@@ -176,11 +181,20 @@ trait CodegenSupport extends SparkPlan {
} else {
input
}
+
+ val evaluated =
+ if (row != null && preferUnsafeRow) {
+ // Current plan can consume UnsafeRows directly.
+ ""
+ } else {
+ evaluateRequiredVariables(child.output, inputVars, usedInputs)
+ }
+
s"""
|
|/*** CONSUME: ${toCommentSafeString(this.simpleString)} */
- |${evaluateRequiredVariables(child.output, inputVars, usedInputs)}
- |${doConsume(ctx, inputVars)}
+ |${evaluated}
+ |${doConsume(ctx, inputVars, row)}
""".stripMargin
}
@@ -195,7 +209,7 @@ trait CodegenSupport extends SparkPlan {
* if (isNull1 || !value2) continue;
* # call consume(), which will call parent.doConsume()
*/
- protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
throw new UnsupportedOperationException
}
}
@@ -238,7 +252,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport
s"""
| while (!shouldStop() && $input.hasNext()) {
| InternalRow $row = (InternalRow) $input.next();
- | ${consume(ctx, columns).trim}
+ | ${consume(ctx, columns, row).trim}
| }
""".stripMargin
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index f856634cf7..1c4d594cd8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -139,7 +139,7 @@ case class TungstenAggregate(
}
}
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
if (groupingExpressions.isEmpty) {
doConsumeWithoutKeys(ctx, input)
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 4901298227..6ebbc8be6f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -49,7 +49,7 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
references.filter(a => usedMoreThanOnce.contains(a.exprId))
}
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
val exprs = projectList.map(x =>
ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output)))
ctx.currentVars = input
@@ -88,7 +88,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
val numOutput = metricTerm(ctx, "numOutputRows")
val expr = ExpressionCanonicalizer.execute(
BindReferences.bindReference(condition, child.output))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index fed88b8c0a..034bf15262 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -136,7 +136,7 @@ package object debug {
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
consume(ctx, input)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index c52662a61e..4c8f8080a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -107,7 +107,7 @@ case class BroadcastHashJoin(
streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)
}
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
if (joinType == Inner) {
codegenInner(ctx, input)
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 5a7516b7f9..ca624a5a84 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -65,7 +65,7 @@ trait BaseLimit extends UnaryNode with CodegenSupport {
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
val stopEarly = ctx.freshName("stopEarly")
ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;")