Diffstat (limited to 'sql/core')
8 files changed, 43 insertions, 23 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
index 524285bc87..a84e180ad1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
@@ -93,7 +93,7 @@ case class Expand(
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
 
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
     /*
      * When the projections list looks like:
      *   expr1A, exprB, expr1C
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
index 2ea889ea72..5a67cd0c24 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
@@ -105,6 +105,8 @@ case class Sort(
   // Name of sorter variable used in codegen.
   private var sorterVariable: String = _
 
+  override def preferUnsafeRow: Boolean = true
+
   override protected def doProduce(ctx: CodegenContext): String = {
     val needToSort = ctx.freshName("needToSort")
     ctx.addMutableState("boolean", needToSort, s"$needToSort = true;")
@@ -153,18 +155,22 @@ case class Sort(
     """.stripMargin.trim
   }
 
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
-    val colExprs = child.output.zipWithIndex.map { case (attr, i) =>
-      BoundReference(i, attr.dataType, attr.nullable)
-    }
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
+    if (row != null) {
+      s"$sorterVariable.insertRow((UnsafeRow)$row);"
+    } else {
+      val colExprs = child.output.zipWithIndex.map { case (attr, i) =>
+        BoundReference(i, attr.dataType, attr.nullable)
+      }
 
-    ctx.currentVars = input
-    val code = GenerateUnsafeProjection.createCode(ctx, colExprs)
+      ctx.currentVars = input
+      val code = GenerateUnsafeProjection.createCode(ctx, colExprs)
 
-    s"""
-       | // Convert the input attributes to an UnsafeRow and add it to the sorter
-       | ${code.code}
-       | $sorterVariable.insertRow(${code.value});
-     """.stripMargin.trim
+      s"""
+         | // Convert the input attributes to an UnsafeRow and add it to the sorter
+         | ${code.code}
+         | $sorterVariable.insertRow(${code.value});
+       """.stripMargin.trim
+    }
   }
 }
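Note: the new row parameter gives Sort a zero-copy fast path. When the upstream operator already exposes a variable holding a complete UnsafeRow (row != null), the generated code inserts it into the sorter directly; only when the input arrives as per-column variables does Sort still emit an UnsafeProjection. A minimal standalone sketch of that branch, using hypothetical helper names rather than Spark's real classes:

    object SortConsumeSketch {
      // `row` is the name of a codegen variable that already holds an UnsafeRow,
      // or null if the parent only has per-column variables available.
      def sortConsume(
          sorterVariable: String,
          row: String,
          projectionCode: String,   // code that builds an UnsafeRow from columns
          projectedRow: String): String = {
        if (row != null) {
          // Fast path: reuse the existing UnsafeRow, no projection needed.
          s"$sorterVariable.insertRow((UnsafeRow) $row);"
        } else {
          // Slow path: materialize the column variables into an UnsafeRow first.
          s"""
             |$projectionCode
             |$sorterVariable.insertRow($projectedRow);
           """.stripMargin.trim
        }
      }

      def main(args: Array[String]): Unit = {
        println(sortConsume("sorter", "inputRow", "", ""))
        println(sortConsume("sorter", null, "UnsafeRow r = project(a, b);", "r"))
      }
    }
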
+ "" + } else { + evaluateRequiredVariables(child.output, inputVars, usedInputs) + } + s""" | |/*** CONSUME: ${toCommentSafeString(this.simpleString)} */ - |${evaluateRequiredVariables(child.output, inputVars, usedInputs)} - |${doConsume(ctx, inputVars)} + |${evaluated} + |${doConsume(ctx, inputVars, row)} """.stripMargin } @@ -195,7 +209,7 @@ trait CodegenSupport extends SparkPlan { * if (isNull1 || !value2) continue; * # call consume(), which will call parent.doConsume() */ - protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = { throw new UnsupportedOperationException } } @@ -238,7 +252,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport s""" | while (!shouldStop() && $input.hasNext()) { | InternalRow $row = (InternalRow) $input.next(); - | ${consume(ctx, columns).trim} + | ${consume(ctx, columns, row).trim} | } """.stripMargin } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index f856634cf7..1c4d594cd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -139,7 +139,7 @@ case class TungstenAggregate( } } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = { if (groupingExpressions.isEmpty) { doConsumeWithoutKeys(ctx, input) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 4901298227..6ebbc8be6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -49,7 +49,7 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) references.filter(a => usedMoreThanOnce.contains(a.exprId)) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = { val exprs = projectList.map(x => ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output))) ctx.currentVars = input @@ -88,7 +88,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = { val numOutput = metricTerm(ctx, "numOutputRows") val expr = ExpressionCanonicalizer.execute( BindReferences.bindReference(condition, child.output)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index fed88b8c0a..034bf15262 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -136,7 +136,7 @@ package object debug { child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index fed88b8c0a..034bf15262 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -136,7 +136,7 @@ package object debug {
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
 
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
     consume(ctx, input)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index c52662a61e..4c8f8080a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -107,7 +107,7 @@ case class BroadcastHashJoin(
     streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
 
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
     if (joinType == Inner) {
       codegenInner(ctx, input)
     } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 5a7516b7f9..ca624a5a84 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -65,7 +65,7 @@ trait BaseLimit extends UnaryNode with CodegenSupport {
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
 
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
     val stopEarly = ctx.freshName("stopEarly")
     ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;")
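Note: the remaining files only thread the new parameter through. Because preferUnsafeRow defaults to false, operators such as Filter, BroadcastHashJoin, and BaseLimit accept row but never read it, so their generated code is unchanged; InputAdapter (which always has a concrete InternalRow in hand) starts passing a non-null value, and only Sort opts in to consume it. A toy illustration (hypothetical names, not Spark's codegen) that ignoring the argument keeps the change backward-compatible:

    object LimitConsumeSketch {
      // BaseLimit-style consumer: purely control flow, so `row` is accepted
      // but unused, and the wider doConsume signature is behavior-preserving.
      def limitConsume(countTerm: String, stopEarly: String, row: String): String =
        s"""
           |if ($stopEarly) continue;
           |$countTerm += 1;
         """.stripMargin.trim
    }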