path: root/sql
author	Davies Liu <davies@databricks.com>	2016-01-16 10:29:27 -0800
committer	Davies Liu <davies.liu@gmail.com>	2016-01-16 10:29:27 -0800
commit	3c0d2365d57fc49ac9bf0d7cc9bd2ef633fb5fb6 (patch)
tree	2e05f1fbb9eec2c870081b00b9e60505db07e1b9 /sql
parent	86972fa52152d2149b88ba75be048a6986006285 (diff)
[SPARK-12796] [SQL] Whole stage codegen
This is the initial work for whole stage codegen, it support Projection/Filter/Range, we will continue work on this to support more physical operators. A micro benchmark show that a query with range, filter and projection could be 3X faster then before. It's turned on by default. For a tree that have at least two chained plans, a WholeStageCodegen will be inserted into it, for example, the following plan ``` Limit 10 +- Project [(id#5L + 1) AS (id + 1)#6L] +- Filter ((id#5L & 1) = 1) +- Range 0, 1, 4, 10, [id#5L] ``` will be translated into ``` Limit 10 +- WholeStageCodegen +- Project [(id#1L + 1) AS (id + 1)#2L] +- Filter ((id#1L & 1) = 1) +- Range 0, 1, 4, 10, [id#1L] ``` Here is the call graph to generate Java source for A and B (A support codegen, but B does not): ``` * WholeStageCodegen Plan A FakeInput Plan B * ========================================================================= * * -> execute() * | * doExecute() --------> produce() * | * doProduce() -------> produce() * | * doProduce() ---> execute() * | * consume() * doConsume() ------------| * | * doConsume() <----- consume() ``` A SparkPlan that support codegen need to implement doProduce() and doConsume(): ``` def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String ``` Author: Davies Liu <davies@databricks.com> Closes #10735 from davies/whole2.
Diffstat (limited to 'sql')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala | 6
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 76
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala | 8
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala | 8
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala | 8
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala | 8
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala | 8
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala | 10
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala | 2
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala | 1
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala | 3
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala | 2
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 3
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala | 9
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 3
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java | 64
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala | 1
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala | 299
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala | 114
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala | 6
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala | 6
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala | 2
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala | 6
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 6
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala | 60
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | 6
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala | 38
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala | 6
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala | 2
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala | 2
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala | 6
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala | 2
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala | 2
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala | 2
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala | 8
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala | 2
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala | 6
37 files changed, 694 insertions, 107 deletions
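The feature is controlled by the new spark.sql.codegen.wholeStage flag added in SQLConf. As an illustrative sketch only (following the new WholeStageCodegenSuite and the benchmark below, not code from this patch), toggling it and confirming that a WholeStageCodegen node appears in the physical plan might look like:

```
sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
val df = sqlContext.range(10).filter("(id & 1) = 1").selectExpr("id + 1")
// The collapsed stage shows up as a WholeStageCodegen node in the physical plan.
assert(df.queryExecution.executedPlan.find(_.isInstanceOf[WholeStageCodegen]).isDefined)
df.explain()
```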
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index dda822d054..4727ff1885 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -61,7 +61,11 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
val javaType = ctx.javaType(dataType)
val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
- if (nullable) {
+ if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
+ ev.isNull = ctx.currentVars(ordinal).isNull
+ ev.value = ctx.currentVars(ordinal).value
+ ""
+ } else if (nullable) {
s"""
boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
$javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index f3a39a0e75..683029ff14 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -56,6 +56,12 @@ class CodegenContext {
val references: mutable.ArrayBuffer[Any] = new mutable.ArrayBuffer[Any]()
/**
+ * Holds a list of generated columns as input of the current operator; it will be used by
+ * BoundReference to generate code.
+ */
+ var currentVars: Seq[ExprCode] = null
+
+ /**
* Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a
* 3-tuple: java type, variable name, code to init it.
* As an example, ("int", "count", "count = 0;") will produce code:
@@ -77,6 +83,16 @@ class CodegenContext {
mutableStates += ((javaType, variableName, initCode))
}
+ def declareMutableStates(): String = {
+ mutableStates.map { case (javaType, variableName, _) =>
+ s"private $javaType $variableName;"
+ }.mkString("\n")
+ }
+
+ def initMutableStates(): String = {
+ mutableStates.map(_._3).mkString("\n")
+ }
+
/**
* Holding all the functions those will be added into generated class.
*/
@@ -111,6 +127,10 @@ class CodegenContext {
// The collection of sub-exression result resetting methods that need to be called on each row.
val subExprResetVariables = mutable.ArrayBuffer.empty[String]
+ def declareAddedFunctions(): String = {
+ addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n")
+ }
+
final val JAVA_BOOLEAN = "boolean"
final val JAVA_BYTE = "byte"
final val JAVA_SHORT = "short"
@@ -120,7 +140,7 @@ class CodegenContext {
final val JAVA_DOUBLE = "double"
/** The variable name of the input row in generated code. */
- final val INPUT_ROW = "i"
+ final var INPUT_ROW = "i"
private val curId = new java.util.concurrent.atomic.AtomicInteger()
@@ -476,20 +496,6 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
protected val genericMutableRowType: String = classOf[GenericMutableRow].getName
- protected def declareMutableStates(ctx: CodegenContext): String = {
- ctx.mutableStates.map { case (javaType, variableName, _) =>
- s"private $javaType $variableName;"
- }.mkString("\n")
- }
-
- protected def initMutableStates(ctx: CodegenContext): String = {
- ctx.mutableStates.map(_._3).mkString("\n")
- }
-
- protected def declareAddedFunctions(ctx: CodegenContext): String = {
- ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n").trim
- }
-
/**
* Generates a class for a given input expression. Called when there is not cached code
* already available.
@@ -505,16 +511,33 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
/** Binds an input expression to a given input schema */
protected def bind(in: InType, inputSchema: Seq[Attribute]): InType
+ /** Generates the requested evaluator binding the given expression(s) to the inputSchema. */
+ def generate(expressions: InType, inputSchema: Seq[Attribute]): OutType =
+ generate(bind(expressions, inputSchema))
+
+ /** Generates the requested evaluator given already bound expression(s). */
+ def generate(expressions: InType): OutType = create(canonicalize(expressions))
+
/**
- * Compile the Java source code into a Java class, using Janino.
+ * Create a new codegen context for expression evaluator, used to store those
+ * expressions that don't support codegen
*/
- protected def compile(code: String): GeneratedClass = {
+ def newCodeGenContext(): CodegenContext = {
+ new CodegenContext
+ }
+}
+
+object CodeGenerator extends Logging {
+ /**
+ * Compile the Java source code into a Java class, using Janino.
+ */
+ def compile(code: String): GeneratedClass = {
cache.get(code)
}
/**
- * Compile the Java source code into a Java class, using Janino.
- */
+ * Compile the Java source code into a Java class, using Janino.
+ */
private[this] def doCompile(code: String): GeneratedClass = {
val evaluator = new ClassBodyEvaluator()
evaluator.setParentClassLoader(Utils.getContextOrSparkClassLoader)
@@ -577,19 +600,4 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
result
}
})
-
- /** Generates the requested evaluator binding the given expression(s) to the inputSchema. */
- def generate(expressions: InType, inputSchema: Seq[Attribute]): OutType =
- generate(bind(expressions, inputSchema))
-
- /** Generates the requested evaluator given already bound expression(s). */
- def generate(expressions: InType): OutType = create(canonicalize(expressions))
-
- /**
- * Create a new codegen context for expression evaluator, used to store those
- * expressions that don't support codegen
- */
- def newCodeGenContext(): CodegenContext = {
- new CodegenContext
- }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
index cface21e5f..f58a2daf90 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen
-import org.apache.spark.sql.catalyst.expressions.{Expression, Nondeterministic}
+import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic}
/**
* A trait that can be used to provide a fallback mode for expression code generation.
@@ -30,13 +30,15 @@ trait CodegenFallback extends Expression {
case _ =>
}
+ // A LeafExpression does not need `input`
+ val input = if (this.isInstanceOf[LeafExpression]) "null" else ctx.INPUT_ROW
val idx = ctx.references.length
ctx.references += this
val objectTerm = ctx.freshName("obj")
if (nullable) {
s"""
/* expression: ${this.toCommentSafeString} */
- Object $objectTerm = ((Expression) references[$idx]).eval(${ctx.INPUT_ROW});
+ Object $objectTerm = ((Expression) references[$idx]).eval($input);
boolean ${ev.isNull} = $objectTerm == null;
${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)};
if (!${ev.isNull}) {
@@ -47,7 +49,7 @@ trait CodegenFallback extends Expression {
ev.isNull = "false"
s"""
/* expression: ${this.toCommentSafeString} */
- Object $objectTerm = ((Expression) references[$idx]).eval(${ctx.INPUT_ROW});
+ Object $objectTerm = ((Expression) references[$idx]).eval($input);
${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm;
"""
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 63d13a8b87..59ef0f5836 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -107,13 +107,13 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
private Object[] references;
private MutableRow mutableRow;
- ${declareMutableStates(ctx)}
- ${declareAddedFunctions(ctx)}
+ ${ctx.declareMutableStates()}
+ ${ctx.declareAddedFunctions()}
public SpecificMutableProjection(Object[] references) {
this.references = references;
mutableRow = new $genericMutableRowType(${expressions.size});
- ${initMutableStates(ctx)}
+ ${ctx.initMutableStates()}
}
public ${classOf[BaseMutableProjection].getName} target(MutableRow row) {
@@ -138,7 +138,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")
- val c = compile(code)
+ val c = CodeGenerator.compile(code)
() => {
c.generate(ctx.references.toArray).asInstanceOf[MutableProjection]
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index e033f62170..6de57537ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -118,12 +118,12 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
class SpecificOrdering extends ${classOf[BaseOrdering].getName} {
private Object[] references;
- ${declareMutableStates(ctx)}
- ${declareAddedFunctions(ctx)}
+ ${ctx.declareMutableStates()}
+ ${ctx.declareAddedFunctions()}
public SpecificOrdering(Object[] references) {
this.references = references;
- ${initMutableStates(ctx)}
+ ${ctx.initMutableStates()}
}
public int compare(InternalRow a, InternalRow b) {
@@ -135,6 +135,6 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
logDebug(s"Generated Ordering: ${CodeFormatter.format(code)}")
- compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering]
+ CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering]
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
index 6fbe12fc65..58065d956f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
@@ -47,12 +47,12 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool
class SpecificPredicate extends ${classOf[Predicate].getName} {
private final Object[] references;
- ${declareMutableStates(ctx)}
- ${declareAddedFunctions(ctx)}
+ ${ctx.declareMutableStates()}
+ ${ctx.declareAddedFunctions()}
public SpecificPredicate(Object[] references) {
this.references = references;
- ${initMutableStates(ctx)}
+ ${ctx.initMutableStates()}
}
public boolean eval(InternalRow ${ctx.INPUT_ROW}) {
@@ -63,7 +63,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool
logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}")
- val p = compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate]
+ val p = CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate]
(r: InternalRow) => p.eval(r)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index 10bd9c6103..e750ad9c18 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -160,13 +160,13 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
private Object[] references;
private MutableRow mutableRow;
- ${declareMutableStates(ctx)}
- ${declareAddedFunctions(ctx)}
+ ${ctx.declareMutableStates()}
+ ${ctx.declareAddedFunctions()}
public SpecificSafeProjection(Object[] references) {
this.references = references;
mutableRow = new $genericMutableRowType(${expressions.size});
- ${initMutableStates(ctx)}
+ ${ctx.initMutableStates()}
}
public java.lang.Object apply(java.lang.Object _i) {
@@ -179,7 +179,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")
- val c = compile(code)
+ val c = CodeGenerator.compile(code)
c.generate(ctx.references.toArray).asInstanceOf[Projection]
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 1a0565a8eb..61e7469ee4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -338,14 +338,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
class SpecificUnsafeProjection extends ${classOf[UnsafeProjection].getName} {
private Object[] references;
-
- ${declareMutableStates(ctx)}
-
- ${declareAddedFunctions(ctx)}
+ ${ctx.declareMutableStates()}
+ ${ctx.declareAddedFunctions()}
public SpecificUnsafeProjection(Object[] references) {
this.references = references;
- ${initMutableStates(ctx)}
+ ${ctx.initMutableStates()}
}
// Scala.Function1 need this
@@ -362,7 +360,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")
- val c = compile(code)
+ val c = CodeGenerator.compile(code)
c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection]
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
index 8781cc77f4..b1ffbaa3e9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
@@ -196,7 +196,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
logDebug(s"SpecificUnsafeRowJoiner($schema1, $schema2):\n${CodeFormatter.format(code)}")
- val c = compile(code)
+ val c = CodeGenerator.compile(code)
c.generate(Array.empty).asInstanceOf[UnsafeRowJoiner]
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 2a24235a29..1eff2c4dd0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -224,6 +224,7 @@ object CaseWhen {
}
}
+
/**
* Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END".
* When a = b, returns c; when a = d, returns e; else returns f.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 2c12de08f4..493e0aae01 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -351,8 +351,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
val hasher = classOf[Murmur3_x86_32].getName
def hashInt(i: String): ExprCode = inlineValue(s"$hasher.hashInt($i, $seed)")
def hashLong(l: String): ExprCode = inlineValue(s"$hasher.hashLong($l, $seed)")
- def inlineValue(v: String): ExprCode =
- ExprCode(code = "", isNull = "false", value = v)
+ def inlineValue(v: String): ExprCode = ExprCode(code = "", isNull = "false", value = v)
dataType match {
case NullType => inlineValue(seed)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index d0b29aa01f..d74f3ef2ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -452,7 +452,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
* depth `i + 1` is the last child of its own parent node. The depth of the root node is 0, and
* `lastChildren` for the root node should be empty.
*/
- protected def generateTreeString(
+ def generateTreeString(
depth: Int, lastChildren: Seq[Boolean], builder: StringBuilder): StringBuilder = {
if (depth > 0) {
lastChildren.init.foreach { isLast =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 3422d0ead4..95e5fbb119 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql
import java.io.CharArrayWriter
-import java.util.Properties
import scala.language.implicitConversions
import scala.reflect.ClassTag
@@ -39,12 +38,10 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution}
import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation}
import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
-import org.apache.spark.sql.sources.HadoopFsRelation
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
-
private[sql] object DataFrame {
def apply(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = {
new DataFrame(sqlContext, logicalPlan)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 4e3662724c..4c1eb0b30b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -489,6 +489,13 @@ private[spark] object SQLConf {
isPublic = false,
doc = "This flag should be set to true to enable support for SQL2011 reserved keywords.")
+ val WHOLESTAGE_CODEGEN_ENABLED = booleanConf("spark.sql.codegen.wholeStage",
+ defaultValue = Some(true),
+ doc = "When true, the whole stage (of multiple operators) will be compiled into a single Java" +
+ " method",
+ isPublic = false)
+
+
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
val EXTERNAL_SORT = "spark.sql.planner.externalSort"
@@ -561,6 +568,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon
private[spark] def nativeView: Boolean = getConf(NATIVE_VIEW)
+ private[spark] def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED)
+
def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE)
private[spark] def subexpressionEliminationEnabled: Boolean =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index a0939adb6d..18ddffe1be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -904,7 +904,8 @@ class SQLContext private[sql](
@transient
protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] {
val batches = Seq(
- Batch("Add exchange", Once, EnsureRequirements(self))
+ Batch("Add exchange", Once, EnsureRequirements(self)),
+ Batch("Whole stage codegen", Once, CollapseCodegenStages(self))
)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
new file mode 100644
index 0000000000..b1bbb1da10
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution;
+
+import scala.collection.Iterator;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+
+/**
+ * An iterator interface used to pull the output from the generated function for multiple operators
+ * (whole stage codegen).
+ *
+ * TODO: replace this with a batched columnar format.
+ */
+public class BufferedRowIterator {
+ protected InternalRow currentRow;
+ protected Iterator<InternalRow> input;
+ // used when there are no columns in the output
+ protected UnsafeRow unsafeRow = new UnsafeRow(0);
+
+ public boolean hasNext() {
+ if (currentRow == null) {
+ processNext();
+ }
+ return currentRow != null;
+ }
+
+ public InternalRow next() {
+ InternalRow r = currentRow;
+ currentRow = null;
+ return r;
+ }
+
+ public void setInput(Iterator<InternalRow> iter) {
+ input = iter;
+ }
+
+ /**
+ * Processes the input until there is a row as output (currentRow).
+ *
+ * After it is called, if currentRow is still null, there are no more rows left.
+ */
+ protected void processNext() {
+ if (input.hasNext()) {
+ currentRow = input.next();
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 2355de3d05..75101ea0fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -97,7 +97,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/** Specifies sort order for each partition requirements on the input data for this operator. */
def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)
-
/**
* Returns the result of this query as an RDD[InternalRow] by delegating to doExecute
* after adding query plan information to created RDDs for visualization.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
new file mode 100644
index 0000000000..c15fabab80
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -0,0 +1,299 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, LeafExpression}
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * An interface for those physical operators that support codegen.
+ */
+trait CodegenSupport extends SparkPlan {
+
+ /**
+ * Whether this SparkPlan support whole stage codegen or not.
+ */
+ def supportCodegen: Boolean = true
+
+ /**
+ * Which SparkPlan is calling produce() of this one. It's itself for the first SparkPlan.
+ */
+ private var parent: CodegenSupport = null
+
+ /**
+ * Returns an input RDD of InternalRow and Java source code to process them.
+ */
+ def produce(ctx: CodegenContext, parent: CodegenSupport): (RDD[InternalRow], String) = {
+ this.parent = parent
+ doProduce(ctx)
+ }
+
+ /**
+ * Generate the Java source code to process rows; this should be overridden by subclasses to support codegen.
+ *
+ * doProduce() usually generates the framework; for example, aggregation could generate this:
+ *
+ * if (!initialized) {
+ * # create a hash map, then build the aggregation hash map
+ * # call child.produce()
+ * initialized = true;
+ * }
+ * while (hashmap.hasNext()) {
+ * row = hashmap.next();
+ * # build the aggregation results
+ * # create variables for results
+ * # call consume(), which will call parent.doConsume()
+ * }
+ */
+ protected def doProduce(ctx: CodegenContext): (RDD[InternalRow], String)
+
+ /**
+ * Consume the columns generated from the current SparkPlan, calling its parent or creating an iterator.
+ */
+ protected def consume(ctx: CodegenContext, columns: Seq[ExprCode]): String = {
+ assert(columns.length == output.length)
+ parent.doConsume(ctx, this, columns)
+ }
+
+
+ /**
+ * Generate the Java source code to process the rows from child SparkPlan.
+ *
+ * This should be overridden by subclasses to support codegen.
+ *
+ * For example, Filter will generate the code like this:
+ *
+ * # code to evaluate the predicate expression, result is isNull1 and value2
+ * if (isNull1 || value2) {
+ * # call consume(), which will call parent.doConsume()
+ * }
+ */
+ def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String
+}
+
+
+/**
+ * InputAdapter is used to hide a SparkPlan from a subtree that supports codegen.
+ *
+ * This is the leaf node of a tree with WholeStageCodegen; it is used to generate code that consumes
+ * an RDD iterator of InternalRow.
+ */
+case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
+
+ override def output: Seq[Attribute] = child.output
+
+ override def supportCodegen: Boolean = true
+
+ override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+ val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
+ val row = ctx.freshName("row")
+ ctx.INPUT_ROW = row
+ ctx.currentVars = null
+ val columns = exprs.map(_.gen(ctx))
+ val code = s"""
+ | while (input.hasNext()) {
+ | InternalRow $row = (InternalRow) input.next();
+ | ${columns.map(_.code).mkString("\n")}
+ | ${consume(ctx, columns)}
+ | }
+ """.stripMargin
+ (child.execute(), code)
+ }
+
+ def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
+ throw new UnsupportedOperationException
+ }
+
+ override def doExecute(): RDD[InternalRow] = {
+ throw new UnsupportedOperationException
+ }
+
+ override def simpleString: String = "INPUT"
+}
+
+/**
+ * WholeStageCodegen compiles a subtree of plans that support codegen together into a single Java
+ * function.
+ *
+ * Here is the call graph to generate Java source (plan A supports codegen, but plan B does not):
+ *
+ *   WholeStageCodegen       Plan A               FakeInput        Plan B
+ * =========================================================================
+ *
+ * -> execute()
+ *     |
+ *  doExecute() -------->   produce()
+ *                             |
+ *                          doProduce() -------> produce()
+ *                                                  |
+ *                                               doProduce() ---> execute()
+ *                                                  |
+ *                                               consume()
+ *                          doConsume() ------------|
+ *                             |
+ *  doConsume() <-----      consume()
+ *
+ * SparkPlan A should override doProduce() and doConsume().
+ *
+ * doCodeGen() will create a CodegenContext, which will hold a list of variables for input,
+ * used to generate code for BoundReference.
+ */
+case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
+ extends SparkPlan with CodegenSupport {
+
+ override def output: Seq[Attribute] = plan.output
+
+ override def doExecute(): RDD[InternalRow] = {
+ val ctx = new CodegenContext
+ val (rdd, code) = plan.produce(ctx, this)
+ val references = ctx.references.toArray
+ val source = s"""
+ public Object generate(Object[] references) {
+ return new GeneratedIterator(references);
+ }
+
+ class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
+
+ private Object[] references;
+ ${ctx.declareMutableStates()}
+
+ public GeneratedIterator(Object[] references) {
+ this.references = references;
+ ${ctx.initMutableStates()}
+ }
+
+ protected void processNext() {
+ $code
+ }
+ }
+ """
+ // try to compile, helpful for debug
+ // println(s"${CodeFormatter.format(source)}")
+ CodeGenerator.compile(source)
+
+ rdd.mapPartitions { iter =>
+ val clazz = CodeGenerator.compile(source)
+ val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
+ buffer.setInput(iter)
+ new Iterator[InternalRow] {
+ override def hasNext: Boolean = buffer.hasNext
+ override def next: InternalRow = buffer.next()
+ }
+ }
+ }
+
+ override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+ throw new UnsupportedOperationException
+ }
+
+ override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
+ if (input.nonEmpty) {
+ val colExprs = output.zipWithIndex.map { case (attr, i) =>
+ BoundReference(i, attr.dataType, attr.nullable)
+ }
+ // generate the code to create a UnsafeRow
+ ctx.currentVars = input
+ val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
+ s"""
+ | ${code.code.trim}
+ | currentRow = ${code.value};
+ | return;
+ """.stripMargin
+ } else {
+ // There are no columns
+ s"""
+ | currentRow = unsafeRow;
+ | return;
+ """.stripMargin
+ }
+ }
+
+ override def generateTreeString(
+ depth: Int,
+ lastChildren: Seq[Boolean],
+ builder: StringBuilder): StringBuilder = {
+ if (depth > 0) {
+ lastChildren.init.foreach { isLast =>
+ val prefixFragment = if (isLast) " " else ": "
+ builder.append(prefixFragment)
+ }
+
+ val branch = if (lastChildren.last) "+- " else ":- "
+ builder.append(branch)
+ }
+
+ builder.append(simpleString)
+ builder.append("\n")
+
+ plan.generateTreeString(depth + 1, lastChildren :+ children.isEmpty :+ true, builder)
+ if (children.nonEmpty) {
+ children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
+ children.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
+ }
+
+ builder
+ }
+
+ override def simpleString: String = "WholeStageCodegen"
+}
+
+
+/**
+ * Find the chained plans that support codegen and collapse them together into a WholeStageCodegen.
+ */
+private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Rule[SparkPlan] {
+
+ private def supportCodegen(plan: SparkPlan): Boolean = plan match {
+ case plan: CodegenSupport if plan.supportCodegen =>
+ // Non-leaf with CodegenFallback does not work with whole stage codegen
+ val willFallback = plan.expressions.exists(
+ _.find(e => e.isInstanceOf[CodegenFallback] && !e.isInstanceOf[LeafExpression]).isDefined
+ )
+ // the generated code will be huge if there are too many columns
+ val haveManyColumns = plan.output.length > 200
+ !willFallback && !haveManyColumns
+ case _ => false
+ }
+
+ def apply(plan: SparkPlan): SparkPlan = {
+ if (sqlContext.conf.wholeStageEnabled) {
+ plan.transform {
+ case plan: CodegenSupport if supportCodegen(plan) &&
+ // Whole stage codegen is only useful when there are at least two levels of operators that
+ // support it (save at least one projection/iterator).
+ plan.children.exists(supportCodegen) =>
+
+ var inputs = ArrayBuffer[SparkPlan]()
+ val combined = plan.transform {
+ case p if !supportCodegen(p) =>
+ inputs += p
+ InputAdapter(p)
+ }.asInstanceOf[CodegenSupport]
+ WholeStageCodegen(combined, inputs)
+ }
+ } else {
+ plan
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 92c9a56131..9e2e0357c6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -22,19 +22,37 @@ import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.LongType
import org.apache.spark.util.MutablePair
import org.apache.spark.util.random.PoissonSampler
-case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode {
+case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
+ extends UnaryNode with CodegenSupport {
override private[sql] lazy val metrics = Map(
"numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows"))
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
+ protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+ child.asInstanceOf[CodegenSupport].produce(ctx, this)
+ }
+
+ override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
+ val exprs = projectList.map(x =>
+ ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output)))
+ ctx.currentVars = input
+ val output = exprs.map(_.gen(ctx))
+ s"""
+ | ${output.map(_.code).mkString("\n")}
+ |
+ | ${consume(ctx, output)}
+ """.stripMargin
+ }
+
protected override def doExecute(): RDD[InternalRow] = {
val numRows = longMetric("numRows")
child.execute().mapPartitionsInternal { iter =>
@@ -51,13 +69,30 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends
}
-case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
+case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode with CodegenSupport {
override def output: Seq[Attribute] = child.output
private[sql] override lazy val metrics = Map(
"numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+ child.asInstanceOf[CodegenSupport].produce(ctx, this)
+ }
+
+ override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
+ val expr = ExpressionCanonicalizer.execute(
+ BindReferences.bindReference(condition, child.output))
+ ctx.currentVars = input
+ val eval = expr.gen(ctx)
+ s"""
+ | ${eval.code}
+ | if (!${eval.isNull} && ${eval.value}) {
+ | ${consume(ctx, ctx.currentVars)}
+ | }
+ """.stripMargin
+ }
+
protected override def doExecute(): RDD[InternalRow] = {
val numInputRows = longMetric("numInputRows")
val numOutputRows = longMetric("numOutputRows")
@@ -116,7 +151,80 @@ case class Range(
numSlices: Int,
numElements: BigInt,
output: Seq[Attribute])
- extends LeafNode {
+ extends LeafNode with CodegenSupport {
+
+ protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+ val initTerm = ctx.freshName("range_initRange")
+ ctx.addMutableState("boolean", initTerm, s"$initTerm = false;")
+ val partitionEnd = ctx.freshName("range_partitionEnd")
+ ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;")
+ val number = ctx.freshName("range_number")
+ ctx.addMutableState("long", number, s"$number = 0L;")
+ val overflow = ctx.freshName("range_overflow")
+ ctx.addMutableState("boolean", overflow, s"$overflow = false;")
+
+ val value = ctx.freshName("range_value")
+ val ev = ExprCode("", "false", value)
+ val BigInt = classOf[java.math.BigInteger].getName
+ val checkEnd = if (step > 0) {
+ s"$number < $partitionEnd"
+ } else {
+ s"$number > $partitionEnd"
+ }
+
+ val rdd = sqlContext.sparkContext.parallelize(0 until numSlices, numSlices)
+ .map(i => InternalRow(i))
+
+ val code = s"""
+ | // initialize Range
+ | if (!$initTerm) {
+ | $initTerm = true;
+ | if (input.hasNext()) {
+ | $BigInt index = $BigInt.valueOf(((InternalRow) input.next()).getInt(0));
+ | $BigInt numSlice = $BigInt.valueOf(${numSlices}L);
+ | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L);
+ | $BigInt step = $BigInt.valueOf(${step}L);
+ | $BigInt start = $BigInt.valueOf(${start}L);
+ |
+ | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
+ | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
+ | $number = Long.MAX_VALUE;
+ | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
+ | $number = Long.MIN_VALUE;
+ | } else {
+ | $number = st.longValue();
+ | }
+ |
+ | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
+ | .multiply(step).add(start);
+ | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
+ | $partitionEnd = Long.MAX_VALUE;
+ | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
+ | $partitionEnd = Long.MIN_VALUE;
+ | } else {
+ | $partitionEnd = end.longValue();
+ | }
+ | } else {
+ | return;
+ | }
+ | }
+ |
+ | while (!$overflow && $checkEnd) {
+ | long $value = $number;
+ | $number += ${step}L;
+ | if ($number < $value ^ ${step}L < 0) {
+ | $overflow = true;
+ | }
+ | ${consume(ctx, Seq(ev))}
+ | }
+ """.stripMargin
+
+ (rdd, code)
+ }
+
+ def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
+ throw new UnsupportedOperationException
+ }
protected override def doExecute(): RDD[InternalRow] = {
sqlContext
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
index 7888e34e8a..72eb1f6cf0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
@@ -143,14 +143,14 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
private DataType[] columnTypes = null;
private int[] columnIndexes = null;
- ${declareMutableStates(ctx)}
+ ${ctx.declareMutableStates()}
public SpecificColumnarIterator() {
this.nativeOrder = ByteOrder.nativeOrder();
this.buffers = new byte[${columnTypes.length}][];
this.mutableRow = new MutableUnsafeRow(rowWriter);
- ${initMutableStates(ctx)}
+ ${ctx.initMutableStates()}
}
public void initialize(Iterator input, DataType[] columnTypes, int[] columnIndexes) {
@@ -190,6 +190,6 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
logDebug(s"Generated ColumnarIterator: ${CodeFormatter.format(code)}")
- compile(code).generate(Array.empty).asInstanceOf[ColumnarIterator]
+ CodeGenerator.compile(code).generate(Array.empty).asInstanceOf[ColumnarIterator]
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 89b9a68768..e8d0678989 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -36,12 +36,12 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
import testImplicits._
def rddIdOf(tableName: String): Int = {
- val executedPlan = sqlContext.table(tableName).queryExecution.executedPlan
- executedPlan.collect {
+ val plan = sqlContext.table(tableName).queryExecution.sparkPlan
+ plan.collect {
case InMemoryColumnarTableScan(_, _, relation) =>
relation.cachedColumnBuffers.id
case _ =>
- fail(s"Table $tableName is not cached\n" + executedPlan)
+ fail(s"Table $tableName is not cached\n" + plan)
}.head
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index eb4efcd1d4..b349bb6dc9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -629,7 +629,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
}
def checkNumProjects(df: DataFrame, expectedNumProjects: Int): Unit = {
- val projects = df.queryExecution.executedPlan.collect {
+ val projects = df.queryExecution.sparkPlan.collect {
case tungstenProject: Project => tungstenProject
}
assert(projects.size === expectedNumProjects)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index 39a65413bd..c17be8ace9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -123,15 +123,15 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
// equijoin - should be converted into broadcast join
- val plan1 = df1.join(broadcast(df2), "key").queryExecution.executedPlan
+ val plan1 = df1.join(broadcast(df2), "key").queryExecution.sparkPlan
assert(plan1.collect { case p: BroadcastHashJoin => p }.size === 1)
// no join key -- should not be a broadcast join
- val plan2 = df1.join(broadcast(df2)).queryExecution.executedPlan
+ val plan2 = df1.join(broadcast(df2)).queryExecution.sparkPlan
assert(plan2.collect { case p: BroadcastHashJoin => p }.size === 0)
// planner should not crash without a join
- broadcast(df1).queryExecution.executedPlan
+ broadcast(df1).queryExecution.sparkPlan
// SPARK-12275: no physical plan for BroadcastHint in some condition
withTempPath { path =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 75e81b9c91..bdb9421cc1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -247,7 +247,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
private def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = {
val df = sql(sqlText)
// First, check if we have GeneratedAggregate.
- val hasGeneratedAgg = df.queryExecution.executedPlan
+ val hasGeneratedAgg = df.queryExecution.sparkPlan
.collect { case _: aggregate.TungstenAggregate => true }
.nonEmpty
if (!hasGeneratedAgg) {
@@ -792,11 +792,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("SPARK-11111 null-safe join should not use cartesian product") {
val df = sql("select count(*) from testData a join testData b on (a.key <=> b.key)")
- val cp = df.queryExecution.executedPlan.collect {
+ val cp = df.queryExecution.sparkPlan.collect {
case cp: CartesianProduct => cp
}
assert(cp.isEmpty, "should not use CartesianProduct for null-safe join")
- val smj = df.queryExecution.executedPlan.collect {
+ val smj = df.queryExecution.sparkPlan.collect {
case smj: SortMergeJoin => smj
}
assert(smj.size > 0, "should use SortMergeJoin")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
new file mode 100644
index 0000000000..788b04fcf8
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.util.Benchmark
+
+/**
+ * Benchmark to measure whole stage codegen performance.
+ * To run this:
+ * build/sbt "sql/test-only *BenchmarkWholeStageCodegen"
+ */
+class BenchmarkWholeStageCodegen extends SparkFunSuite {
+ def testWholeStage(values: Int): Unit = {
+ val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark")
+ val sc = SparkContext.getOrCreate(conf)
+ val sqlContext = SQLContext.getOrCreate(sc)
+
+ val benchmark = new Benchmark("Single Int Column Scan", values)
+
+ benchmark.addCase("Without whole stage codegen") { iter =>
+ sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
+ sqlContext.range(values).filter("(id & 1) = 1").count()
+ }
+
+ benchmark.addCase("With whole stage codegen") { iter =>
+ sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
+ sqlContext.range(values).filter("(id & 1) = 1").count()
+ }
+
+ /*
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate
+ -------------------------------------------------------------------------
+ Without whole stage codegen 6725.52 31.18 1.00 X
+ With whole stage codegen 2233.05 93.91 3.01 X
+ */
+ benchmark.run()
+ }
+
+ ignore("benchmark") {
+ testWholeStage(1024 * 1024 * 200)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 03a1b8e11d..49feeaf17d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -94,7 +94,7 @@ class PlannerSuite extends SharedSQLContext {
"""
|SELECT l.a, l.b
|FROM testData2 l JOIN (SELECT * FROM testLimit LIMIT 1) r ON (l.a = r.key)
- """.stripMargin).queryExecution.executedPlan
+ """.stripMargin).queryExecution.sparkPlan
val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
val sortMergeJoins = planned.collect { case join: SortMergeJoin => join }
@@ -147,7 +147,7 @@ class PlannerSuite extends SharedSQLContext {
val a = testData.as("a")
val b = sqlContext.table("tiny").as("b")
- val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan
+ val planned = a.join(b, $"a.key" === $"b.key").queryExecution.sparkPlan
val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
val sortMergeJoins = planned.collect { case join: SortMergeJoin => join }
@@ -168,7 +168,7 @@ class PlannerSuite extends SharedSQLContext {
sqlContext.registerDataFrameAsTable(df, "testPushed")
withTempTable("testPushed") {
- val exp = sql("select * from testPushed where key = 15").queryExecution.executedPlan
+ val exp = sql("select * from testPushed where key = 15").queryExecution.sparkPlan
assert(exp.toString.contains("PushedFilters: [EqualTo(key,15)]"))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
new file mode 100644
index 0000000000..c54fc6ba2d
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.test.SharedSQLContext
+
+class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
+
+ test("range/filter should be combined") {
+ val df = sqlContext.range(10).filter("id = 1").selectExpr("id + 1")
+ val plan = df.queryExecution.executedPlan
+ assert(plan.find(_.isInstanceOf[WholeStageCodegen]).isDefined)
+
+ checkThatPlansAgree(
+ sqlContext.range(100),
+ (p: SparkPlan) =>
+ WholeStageCodegen(Filter('a == 1, InputAdapter(p)), Seq()),
+ (p: SparkPlan) => Filter('a == 1, p),
+ sortAnswers = false
+ )
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
index 25afed25c8..6e21d5a061 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -31,7 +31,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
setupTestData()
test("simple columnar query") {
- val plan = sqlContext.executePlan(testData.logicalPlan).executedPlan
+ val plan = sqlContext.executePlan(testData.logicalPlan).sparkPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
checkAnswer(scan, testData.collect().toSeq)
@@ -48,7 +48,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
}
test("projection") {
- val plan = sqlContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan
+ val plan = sqlContext.executePlan(testData.select('value, 'key).logicalPlan).sparkPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
checkAnswer(scan, testData.collect().map {
@@ -57,7 +57,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
}
test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") {
- val plan = sqlContext.executePlan(testData.logicalPlan).executedPlan
+ val plan = sqlContext.executePlan(testData.logicalPlan).sparkPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
checkAnswer(scan, testData.collect().toSeq)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
index d762f7bfe9..647a7e9a4e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
@@ -114,7 +114,7 @@ class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext {
df.collect().map(_(0)).toArray
}
- val (readPartitions, readBatches) = df.queryExecution.executedPlan.collect {
+ val (readPartitions, readBatches) = df.queryExecution.sparkPlan.collect {
case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value)
}.head
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 58581d71e1..aee8e84db5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -62,7 +62,7 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
// Comparison at the end is for broadcast left semi join
val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")
val df3 = df1.join(broadcast(df2), joinExpression, joinType)
- val plan = df3.queryExecution.executedPlan
+ val plan = df3.queryExecution.sparkPlan
assert(plan.collect { case p: T => p }.size === 1)
plan.executeCollect()
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
index 9b37dd1103..11863caffe 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
@@ -30,12 +30,12 @@ class CachedTableSuite extends QueryTest with TestHiveSingleton {
import hiveContext._
def rddIdOf(tableName: String): Int = {
- val executedPlan = table(tableName).queryExecution.executedPlan
- executedPlan.collect {
+ val plan = table(tableName).queryExecution.sparkPlan
+ plan.collect {
case InMemoryColumnarTableScan(_, _, relation) =>
relation.cachedColumnBuffers.id
case _ =>
- fail(s"Table $tableName is not cached\n" + executedPlan)
+ fail(s"Table $tableName is not cached\n" + plan)
}.head
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
index fd3339a66b..2e0a8698e6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -485,7 +485,7 @@ abstract class HiveComparisonTest
val executions = queryList.map(new TestHive.QueryExecution(_))
executions.foreach(_.toRdd)
val tablesGenerated = queryList.zip(executions).flatMap {
- case (q, e) => e.executedPlan.collect {
+ case (q, e) => e.sparkPlan.collect {
case i: InsertIntoHiveTable if tablesRead contains i.table.tableName =>
(q, e, i)
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala
index 5bd323ea09..d2f91861ff 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala
@@ -43,7 +43,7 @@ class HiveTypeCoercionSuite extends HiveComparisonTest {
test("[SPARK-2210] boolean cast on boolean value should be removed") {
val q = "select cast(cast(key=0 as boolean) as boolean) from src"
- val project = TestHive.sql(q).queryExecution.executedPlan.collect {
+ val project = TestHive.sql(q).queryExecution.sparkPlan.collect {
case e: Project => e
}.head
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
index 210d566745..b91248bfb3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
@@ -144,7 +144,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter {
expectedScannedColumns: Seq[String],
expectedPartValues: Seq[Seq[String]]): Unit = {
test(s"$testCaseName - pruning test") {
- val plan = new TestHive.QueryExecution(sql).executedPlan
+ val plan = new TestHive.QueryExecution(sql).sparkPlan
val actualOutputColumns = plan.output.map(_.name)
val (actualScannedColumns, actualPartValues) = plan.collect {
case p @ HiveTableScan(columns, relation, _) =>
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
index ed544c6380..c997453803 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
@@ -190,11 +190,11 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
test(s"conversion is working") {
assert(
- sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect {
+ sql("SELECT * FROM normal_parquet").queryExecution.sparkPlan.collect {
case _: HiveTableScan => true
}.isEmpty)
assert(
- sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect {
+ sql("SELECT * FROM normal_parquet").queryExecution.sparkPlan.collect {
case _: PhysicalRDD => true
}.nonEmpty)
}
@@ -305,7 +305,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
""".stripMargin)
val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt")
- df.queryExecution.executedPlan match {
+ df.queryExecution.sparkPlan match {
case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation, _, _)) => // OK
case o => fail("test_insert_parquet should be converted to a " +
s"${classOf[ParquetRelation].getCanonicalName} and " +
@@ -335,7 +335,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
""".stripMargin)
val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array")
- df.queryExecution.executedPlan match {
+ df.queryExecution.sparkPlan match {
case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation, _, _)) => // OK
case o => fail("test_insert_parquet should be converted to a " +
s"${classOf[ParquetRelation].getCanonicalName} and " +
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala
index e866493ee6..ba2a483bba 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala
@@ -149,7 +149,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
sqlContext.range(2).select('id as 'a, 'id as 'b).write.partitionBy("b").parquet(path)
val df = sqlContext.read.parquet(path).filter('a === 0).select('b)
- val physicalPlan = df.queryExecution.executedPlan
+ val physicalPlan = df.queryExecution.sparkPlan
assert(physicalPlan.collect { case p: execution.Project => p }.length === 1)
assert(physicalPlan.collect { case p: execution.Filter => p }.length === 1)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
index 058c101eeb..9ab3e11609 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
@@ -156,9 +156,9 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with Predicat
test(s"pruning and filtering: df.select(${projections.mkString(", ")}).where($filter)") {
val df = partitionedDF.where(filter).select(projections: _*)
val queryExecution = df.queryExecution
- val executedPlan = queryExecution.executedPlan
+ val sparkPlan = queryExecution.sparkPlan
- val rawScan = executedPlan.collect {
+ val rawScan = sparkPlan.collect {
case p: PhysicalRDD => p
} match {
case Seq(scan) => scan
@@ -177,7 +177,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with Predicat
assert(requiredColumns === SimpleTextRelation.requiredColumns)
val nonPushedFilters = {
- val boundFilters = executedPlan.collect {
+ val boundFilters = sparkPlan.collect {
case f: execution.Filter => f
} match {
case Nil => Nil