path: root/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
Diffstat (limited to 'sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala  228
1 file changed, 126 insertions(+), 102 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 1b13c8fd22..447dbe7018 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution
-import org.apache.spark.broadcast
+import org.apache.spark.{broadcast, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -29,10 +29,11 @@ import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin}
import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
/**
- * An interface for those physical operators that support codegen.
- */
+ * An interface for those physical operators that support codegen.
+ */
trait CodegenSupport extends SparkPlan {
/** Prefix used in the current operator's variable names. */
@@ -46,10 +47,10 @@ trait CodegenSupport extends SparkPlan {
}
/**
- * Creates a metric using the specified name.
- *
- * @return name of the variable representing the metric
- */
+ * Creates a metric using the specified name.
+ *
+ * @return name of the variable representing the metric
+ */
def metricTerm(ctx: CodegenContext, name: String): String = {
val metric = ctx.addReferenceObj(name, longMetric(name))
val value = ctx.freshName("metricValue")
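For orientation, a hedged sketch of how the returned variable name is typically spliced into an operator's generated Java (the metric name and its use are illustrative, not taken from this diff):

    // Inside some operator's code generation: count every row that reaches
    // this point. metricTerm() registers the LongSQLMetricValue as a
    // reference object and returns a fresh variable name that refers to it.
    val numOutput = metricTerm(ctx, "numOutputRows")
    val countCode = s"$numOutput.add(1);" // spliced into the generated Java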
@@ -59,25 +60,25 @@ trait CodegenSupport extends SparkPlan {
}
/**
- * Whether this SparkPlan support whole stage codegen or not.
- */
+ * Whether this SparkPlan supports whole-stage codegen or not.
+ */
def supportCodegen: Boolean = true
/**
- * Which SparkPlan is calling produce() of this one. It's itself for the first SparkPlan.
- */
+ * Which SparkPlan is calling produce() on this one. It is itself for the first SparkPlan.
+ */
protected var parent: CodegenSupport = null
/**
- * Returns all the RDDs of InternalRow which generates the input rows.
- *
- * Note: right now we support up to two RDDs.
- */
+ * Returns all the RDDs of InternalRow which generate the input rows.
+ *
+ * Note: right now we support up to two RDDs.
+ */
def upstreams(): Seq[RDD[InternalRow]]
/**
- * Returns Java source code to process the rows from upstream.
- */
+ * Returns Java source code to process the rows from upstream.
+ */
final def produce(ctx: CodegenContext, parent: CodegenSupport): String = {
this.parent = parent
ctx.freshNamePrefix = variablePrefix
@@ -89,28 +90,28 @@ trait CodegenSupport extends SparkPlan {
}
/**
- * Generate the Java source code to process, should be overridden by subclass to support codegen.
- *
- * doProduce() usually generate the framework, for example, aggregation could generate this:
- *
- * if (!initialized) {
- * # create a hash map, then build the aggregation hash map
- * # call child.produce()
- * initialized = true;
- * }
- * while (hashmap.hasNext()) {
- * row = hashmap.next();
- * # build the aggregation results
- * # create variables for results
- * # call consume(), which will call parent.doConsume()
+ * Generate the Java source code to process rows; should be overridden by subclasses to support codegen.
+ *
+ * doProduce() usually generates the framework; for example, aggregation could generate this:
+ *
+ * if (!initialized) {
+ * # create a hash map, then build the aggregation hash map
+ * # call child.produce()
+ * initialized = true;
+ * }
+ * while (hashmap.hasNext()) {
+ * row = hashmap.next();
+ * # build the aggregation results
+ * # create variables for results
+ * # call consume(), which will call parent.doConsume()
* if (shouldStop()) return;
- * }
- */
+ * }
+ */
protected def doProduce(ctx: CodegenContext): String
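As a companion to the scaladoc above, here is a hedged sketch of a minimal doProduce() for a leaf-like operator that drains an iterator handed in through inputs[] (this mirrors the InputAdapter pattern; all names are illustrative):

    override protected def doProduce(ctx: CodegenContext): String = {
      // Register the iterator as mutable state on the generated class,
      // initialized from the inputs[] array passed to init().
      val input = ctx.freshName("input")
      ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
      val row = ctx.freshName("row")
      s"""
         |while ($input.hasNext()) {
         |  InternalRow $row = (InternalRow) $input.next();
         |  ${consume(ctx, null, row).trim}
         |  if (shouldStop()) return;
         |}
       """.stripMargin
    }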
/**
- * Consume the generated columns or row from current SparkPlan, call it's parent's doConsume().
- */
+ * Consume the generated columns or row from the current SparkPlan, calling its parent's doConsume().
+ */
final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null): String = {
val inputVars =
if (row != null) {
@@ -152,15 +153,15 @@ trait CodegenSupport extends SparkPlan {
s"""
|
|/*** CONSUME: ${toCommentSafeString(parent.simpleString)} */
- |${evaluated}
+ |$evaluated
|${parent.doConsume(ctx, inputVars, rowVar)}
""".stripMargin
}
/**
- * Returns source code to evaluate all the variables, and clear the code of them, to prevent
- * them to be evaluated twice.
- */
+ * Returns source code to evaluate all the variables, and clears their code to prevent
+ * them from being evaluated twice.
+ */
protected def evaluateVariables(variables: Seq[ExprCode]): String = {
val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n")
variables.foreach(_.code = "")
@@ -168,21 +169,21 @@ trait CodegenSupport extends SparkPlan {
}
/**
- * Returns source code to evaluate the variables for required attributes, and clear the code
- * of evaluated variables, to prevent them to be evaluated twice..
- */
+ * Returns source code to evaluate the variables for required attributes, and clears the code
+ * of the evaluated variables to prevent them from being evaluated twice.
+ */
protected def evaluateRequiredVariables(
attributes: Seq[Attribute],
variables: Seq[ExprCode],
required: AttributeSet): String = {
- var evaluateVars = ""
+ val evaluateVars = new StringBuilder
variables.zipWithIndex.foreach { case (ev, i) =>
if (ev.code != "" && required.contains(attributes(i))) {
- evaluateVars += ev.code.trim + "\n"
+ evaluateVars.append(ev.code.trim + "\n")
ev.code = ""
}
}
- evaluateVars
+ evaluateVars.toString()
}
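Blanking ev.code after emission is what makes this safe: once a variable's snippet has been emitted, its ExprCode can no longer inject the same code downstream. A hedged, minimal illustration of the contract (the snippet itself is fake):

    import org.apache.spark.sql.catalyst.expressions.codegen.ExprCode

    // ExprCode carries a mutable `code` field in this era of Spark.
    val ev = ExprCode("int value1 = i.getInt(0);", "false", "value1")
    // After evaluateVariables(Seq(ev)) has returned the snippet once,
    // ev.code == "" and a later consume() cannot emit it a second time.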
/**
@@ -194,18 +195,18 @@ trait CodegenSupport extends SparkPlan {
def usedInputs: AttributeSet = references
/**
- * Generate the Java source code to process the rows from child SparkPlan.
- *
- * This should be override by subclass to support codegen.
- *
- * For example, Filter will generate the code like this:
- *
- * # code to evaluate the predicate expression, result is isNull1 and value2
- * if (isNull1 || !value2) continue;
- * # call consume(), which will call parent.doConsume()
- *
- * Note: A plan can either consume the rows as UnsafeRow (row), or a list of variables (input).
- */
+ * Generate the Java source code to process the rows from child SparkPlan.
+ *
+ * This should be overridden by subclasses to support codegen.
+ *
+ * For example, Filter will generate the code like this:
+ *
+ * # code to evaluate the predicate expression, result is isNull1 and value2
+ * if (isNull1 || !value2) continue;
+ * # call consume(), which will call parent.doConsume()
+ *
+ * Note: A plan can either consume the rows as UnsafeRow (row), or a list of variables (input).
+ */
def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
throw new UnsupportedOperationException
}
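To make the Filter example from the scaladoc concrete, here is a hedged sketch of such an override (the pre-bound predicate and the expression-codegen entry point are assumptions; helper names shifted across Spark versions):

    override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
      // `boundPredicate` is a hypothetical condition already bound to child.output.
      ctx.currentVars = input
      val eval = boundPredicate.gen(ctx) // genCode(ctx) in later versions
      s"""
         |${eval.code.trim}
         |if (${eval.isNull} || !${eval.value}) continue;
         |${consume(ctx, input)}
       """.stripMargin
    }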
@@ -213,11 +214,11 @@ trait CodegenSupport extends SparkPlan {
/**
- * InputAdapter is used to hide a SparkPlan from a subtree that support codegen.
- *
- * This is the leaf node of a tree with WholeStageCodegen, is used to generate code that consumes
- * an RDD iterator of InternalRow.
- */
+ * InputAdapter is used to hide a SparkPlan from a subtree that supports codegen.
+ *
+ * This is the leaf node of a WholeStageCodegen tree, and is used to generate code that consumes
+ * an RDD iterator of InternalRow.
+ */
case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport {
override def output: Seq[Attribute] = child.output
@@ -260,33 +261,33 @@ object WholeStageCodegen {
}
/**
- * WholeStageCodegen compile a subtree of plans that support codegen together into single Java
- * function.
- *
- * Here is the call graph of to generate Java source (plan A support codegen, but plan B does not):
- *
- * WholeStageCodegen Plan A FakeInput Plan B
- * =========================================================================
- *
- * -> execute()
- * |
- * doExecute() ---------> upstreams() -------> upstreams() ------> execute()
- * |
- * +-----------------> produce()
- * |
- * doProduce() -------> produce()
- * |
- * doProduce()
- * |
- * doConsume() <--------- consume()
- * |
- * doConsume() <-------- consume()
- *
- * SparkPlan A should override doProduce() and doConsume().
- *
- * doCodeGen() will create a CodeGenContext, which will hold a list of variables for input,
- * used to generated code for BoundReference.
- */
+ * WholeStageCodegen compiles a subtree of plans that support codegen together into a single Java
+ * function.
+ *
+ * Here is the call graph for generating the Java source (plan A supports codegen, but plan B does not):
+ *
+ * WholeStageCodegen Plan A FakeInput Plan B
+ * =========================================================================
+ *
+ * -> execute()
+ * |
+ * doExecute() ---------> upstreams() -------> upstreams() ------> execute()
+ * |
+ * +-----------------> produce()
+ * |
+ * doProduce() -------> produce()
+ * |
+ * doProduce()
+ * |
+ * doConsume() <--------- consume()
+ * |
+ * doConsume() <-------- consume()
+ *
+ * SparkPlan A should override doProduce() and doConsume().
+ *
+ * doCodeGen() will create a CodeGenContext, which will hold a list of variables for input,
+ * used to generate code for BoundReference.
+ */
case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSupport {
override def output: Seq[Attribute] = child.output
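The collapse is easy to observe from the API side. Assuming a Spark build of this era, chained codegen-friendly operators end up under a single WholeStageCodegen node (the plan rendering below is approximate):

    val df = sqlContext.range(1000).filter("id > 100").selectExpr("id + 1")
    df.explain()
    // == Physical Plan ==
    // WholeStageCodegen
    // :  +- Project [(id#0L + 1) AS (id + 1)#3L]
    // :     +- Filter (id#0L > 100)
    // :        +- Range ...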
@@ -297,18 +298,22 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
"pipelineTime" -> SQLMetrics.createTimingMetric(sparkContext,
WholeStageCodegen.PIPELINE_DURATION_METRIC))
- override def doExecute(): RDD[InternalRow] = {
+ /**
+ * Generates code for this subtree.
+ *
+ * @return the tuple of the codegen context and the actual generated source.
+ */
+ def doCodeGen(): (CodegenContext, String) = {
val ctx = new CodegenContext
val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
- val references = ctx.references.toArray
val source = s"""
public Object generate(Object[] references) {
return new GeneratedIterator(references);
}
/** Codegened pipeline for:
- * ${toCommentSafeString(child.treeString.trim)}
- */
+ * ${toCommentSafeString(child.treeString.trim)}
+ */
final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
private Object[] references;
@@ -318,7 +323,8 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
this.references = references;
}
- public void init(scala.collection.Iterator inputs[]) {
+ public void init(int index, scala.collection.Iterator inputs[]) {
+ partitionIndex = index;
${ctx.initMutableStates()}
}
@@ -332,18 +338,24 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
// try to compile, helpful for debug
val cleanedSource = CodeFormatter.stripExtraNewLines(source)
- logDebug(s"${CodeFormatter.format(cleanedSource)}")
+ logDebug(s"\n${CodeFormatter.format(cleanedSource)}")
CodeGenerator.compile(cleanedSource)
+ (ctx, cleanedSource)
+ }
+
+ override def doExecute(): RDD[InternalRow] = {
+ val (ctx, cleanedSource) = doCodeGen()
+ val references = ctx.references.toArray
val durationMs = longMetric("pipelineTime")
val rdds = child.asInstanceOf[CodegenSupport].upstreams()
assert(rdds.size <= 2, "Up to two upstream RDDs can be supported")
if (rdds.length == 1) {
- rdds.head.mapPartitions { iter =>
+ rdds.head.mapPartitionsWithIndex { (index, iter) =>
val clazz = CodeGenerator.compile(cleanedSource)
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
- buffer.init(Array(iter))
+ buffer.init(index, Array(iter))
new Iterator[InternalRow] {
override def hasNext: Boolean = {
val v = buffer.hasNext
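The switch from mapPartitions to mapPartitionsWithIndex is what feeds the new init(index, inputs) signature. A small, self-contained illustration of the mechanism, using nothing Spark-SQL-specific:

    import org.apache.spark.{SparkConf, SparkContext}

    object PartitionIndexDemo {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setMaster("local[2]").setAppName("partition-index-demo"))
        // Each partition is handed its own index, the same value the
        // generated iterator stores as partitionIndex.
        val tagged = sc.parallelize(1 to 4, 2).mapPartitionsWithIndex { (index, iter) =>
          iter.map(x => s"partition=$index value=$x")
        }
        tagged.collect().foreach(println)
        sc.stop()
      }
    }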
@@ -356,9 +368,10 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
} else {
// Right now, we support up to two upstreams.
rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) =>
+ val partitionIndex = TaskContext.getPartitionId()
val clazz = CodeGenerator.compile(cleanedSource)
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
- buffer.init(Array(leftIter, rightIter))
+ buffer.init(partitionIndex, Array(leftIter, rightIter))
new Iterator[InternalRow] {
override def hasNext: Boolean = {
val v = buffer.hasNext
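zipPartitions has no index-passing variant, hence the TaskContext.getPartitionId() call; it works only because the closure runs inside a task, where a TaskContext is set. A hedged sketch of the same pattern on plain RDDs (left and right are assumed to be co-partitioned RDDs):

    import org.apache.spark.TaskContext

    val pairs = left.zipPartitions(right) { (l, r) =>
      // Recover the partition index from the task-side context.
      val index = TaskContext.getPartitionId()
      l.zip(r).map { case (a, b) => s"partition=$index pair=($a,$b)" }
    }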
@@ -409,8 +422,8 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
/**
- * Find the chained plans that support codegen, collapse them together as WholeStageCodegen.
- */
+ * Find the chained plans that support codegen, collapse them together as WholeStageCodegen.
+ */
case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {
private def supportCodegen(e: Expression): Boolean = e match {
@@ -421,12 +434,23 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {
case _ => true
}
+ private def numOfNestedFields(dataType: DataType): Int = dataType match {
+ case dt: StructType => dt.fields.map(f => numOfNestedFields(f.dataType)).sum
+ case m: MapType => numOfNestedFields(m.keyType) + numOfNestedFields(m.valueType)
+ case a: ArrayType => numOfNestedFields(a.elementType)
+ case u: UserDefinedType[_] => numOfNestedFields(u.sqlType)
+ case _ => 1
+ }
+
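The recursion prices a schema by its leaf fields rather than its top-level columns, so deeply nested schemas count at their true width. A hedged usage sketch (numOfNestedFields is private to this rule; the call is shown only for illustration):

    import org.apache.spark.sql.types._

    // Two top-level columns, but three leaf fields: a, s.b and s.c.
    val schema = new StructType()
      .add("a", IntegerType)
      .add("s", new StructType().add("b", LongType).add("c", StringType))
    // numOfNestedFields(schema) == 3, while schema.length == 2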
private def supportCodegen(plan: SparkPlan): Boolean = plan match {
case plan: CodegenSupport if plan.supportCodegen =>
val willFallback = plan.expressions.exists(_.find(e => !supportCodegen(e)).isDefined)
// the generated code will be huge if there are too many columns
- val haveManyColumns = plan.output.length > 200
- !willFallback && !haveManyColumns
+ val hasTooManyOutputFields =
+ numOfNestedFields(plan.schema) > conf.wholeStageMaxNumFields
+ val hasTooManyInputFields =
+ plan.children.map(p => numOfNestedFields(p.schema)).exists(_ > conf.wholeStageMaxNumFields)
+ !willFallback && !hasTooManyOutputFields && !hasTooManyInputFields
case _ => false
}
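A hedged usage note: wholeStageMaxNumFields is assumed here to be backed by the SQLConf key spark.sql.codegen.maxFields, which turns the old hard-coded 200-column cutoff into a tunable budget:

    // Assumed config key; lowering it forces wide plans back to the
    // interpreted (non-codegen) path.
    sqlContext.setConf("spark.sql.codegen.maxFields", "64")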