aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorMichael Armbrust <michael@databricks.com>2015-08-05 09:01:45 -0700
committerYin Huai <yhuai@databricks.com>2015-08-05 09:01:45 -0700
commit23d982204bb9ef74d3b788a32ce6608116968719 (patch)
tree4faef48a8551c8da604463613c2c1b6fc9762d30 /sql
parent34dcf10104460816382908b2b8eeb6c925e862bf (diff)
downloadspark-23d982204bb9ef74d3b788a32ce6608116968719.tar.gz
spark-23d982204bb9ef74d3b788a32ce6608116968719.tar.bz2
spark-23d982204bb9ef74d3b788a32ce6608116968719.zip
[SPARK-9141] [SQL] Remove project collapsing from DataFrame API
Currently we collapse successive projections that are added by `withColumn`. However, this optimization violates the constraint that adding nodes to a plan will never change its analyzed form and thus breaks caching. Instead of doing early optimization, in this PR I just fix some low-hanging slowness in the analyzer. In particular, I add a mechanism for skipping already analyzed subplans, `resolveOperators` and `resolveExpression`. Since trees are generally immutable after construction, it's safe to annotate a plan as already analyzed as any transformation will create a new tree with this bit no longer set. Together these result in a faster analyzer than before, even with added timing instrumentation. ``` Original Code [info] 3430ms [info] 2205ms [info] 1973ms [info] 1982ms [info] 1916ms Without Project Collapsing in DataFrame [info] 44610ms [info] 45977ms [info] 46423ms [info] 46306ms [info] 54723ms With analyzer optimizations [info] 6394ms [info] 4630ms [info] 4388ms [info] 4093ms [info] 4113ms With resolveOperators [info] 2495ms [info] 1380ms [info] 1685ms [info] 1414ms [info] 1240ms ``` Author: Michael Armbrust <michael@databricks.com> Closes #7920 from marmbrus/withColumnCache and squashes the following commits: 2145031 [Michael Armbrust] fix hive udfs tests 5a5a525 [Michael Armbrust] remove wrong comment 7a507d5 [Michael Armbrust] style b59d710 [Michael Armbrust] revert small change 1fa5949 [Michael Armbrust] move logic into LogicalPlan, add tests 0e2cb43 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into withColumnCache c926e24 [Michael Armbrust] naming e593a2d [Michael Armbrust] style f5a929e [Michael Armbrust] [SPARK-9141][SQL] Remove project collapsing from DataFrame API 38b1c83 [Michael Armbrust] WIP
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala28
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala3
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala30
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala5
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala51
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala22
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala64
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala72
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala6
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala2
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala20
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala12
-rw-r--r--sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala5
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala5
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala3
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala8
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala2
18 files changed, 234 insertions, 106 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index ca17f3e3d0..6de31f42dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -90,7 +90,7 @@ class Analyzer(
*/
object CTESubstitution extends Rule[LogicalPlan] {
// TODO allow subquery to define CTE
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case With(child, relations) => substituteCTE(child, relations)
case other => other
}
@@ -116,7 +116,7 @@ class Analyzer(
* Substitute child plan with WindowSpecDefinitions.
*/
object WindowsSubstitution extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
// Lookup WindowSpecDefinitions. This rule works with unresolved children.
case WithWindowDefinition(windowDefinitions, child) =>
child.transform {
@@ -140,7 +140,7 @@ class Analyzer(
object ResolveAliases extends Rule[LogicalPlan] {
private def assignAliases(exprs: Seq[NamedExpression]) = {
// The `UnresolvedAlias`s will appear only at root of a expression tree, we don't need
- // to transform down the whole tree.
+ // to resolveOperator down the whole tree.
exprs.zipWithIndex.map {
case (u @ UnresolvedAlias(child), i) =>
child match {
@@ -156,7 +156,7 @@ class Analyzer(
}
}
- def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case Aggregate(groups, aggs, child)
if child.resolved && aggs.exists(_.isInstanceOf[UnresolvedAlias]) =>
Aggregate(groups, assignAliases(aggs), child)
@@ -198,7 +198,7 @@ class Analyzer(
Seq.tabulate(1 << c.groupByExprs.length)(i => i)
}
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case a if !a.childrenResolved => a // be sure all of the children are resolved.
case a: Cube =>
GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations)
@@ -261,7 +261,7 @@ class Analyzer(
}
}
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case i @ InsertIntoTable(u: UnresolvedRelation, _, _, _, _) =>
i.copy(table = EliminateSubQueries(getTable(u)))
case u: UnresolvedRelation =>
@@ -274,7 +274,7 @@ class Analyzer(
* a logical plan node's children.
*/
object ResolveReferences extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p: LogicalPlan if !p.childrenResolved => p
// If the projection list contains Stars, expand it.
@@ -444,7 +444,7 @@ class Analyzer(
* remove these attributes after sorting.
*/
object ResolveSortReferences extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case s @ Sort(ordering, global, p @ Project(projectList, child))
if !s.resolved && p.resolved =>
val (newOrdering, missing) = resolveAndFindMissing(ordering, p, child)
@@ -519,7 +519,7 @@ class Analyzer(
* Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s.
*/
object ResolveFunctions extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case q: LogicalPlan =>
q transformExpressions {
case u @ UnresolvedFunction(name, children, isDistinct) =>
@@ -551,7 +551,7 @@ class Analyzer(
* Turns projections that contain aggregate expressions into aggregations.
*/
object GlobalAggregates extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case Project(projectList, child) if containsAggregates(projectList) =>
Aggregate(Nil, projectList, child)
}
@@ -571,7 +571,7 @@ class Analyzer(
* aggregates and then projects them away above the filter.
*/
object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _))
if aggregate.resolved && containsAggregate(havingCondition) =>
@@ -601,7 +601,7 @@ class Analyzer(
* [[AnalysisException]] is throw.
*/
object ResolveGenerate extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p: Generate if !p.child.resolved || !p.generator.resolved => p
case g: Generate if !g.resolved =>
g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name)))
@@ -872,6 +872,8 @@ class Analyzer(
// We have to use transformDown at here to make sure the rule of
// "Aggregate with Having clause" will be triggered.
def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
+
+
// Aggregate with Having clause. This rule works with an unresolved Aggregate because
// a resolved Aggregate will not have Window Functions.
case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
@@ -927,7 +929,7 @@ class Analyzer(
* put them into an inner Project and finally project them away at the outer Project.
*/
object PullOutNondeterministic extends Rule[LogicalPlan] {
- override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p: Project => p
case f: Filter => f
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 187b238045..39f554c137 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -47,6 +47,7 @@ trait CheckAnalysis {
// We transform up and order the rules so as to catch the first possible failure instead
// of the result of cascading resolution failures.
plan.foreachUp {
+ case p if p.analyzed => // Skip already analyzed sub-plans
case operator: LogicalPlan =>
operator transformExpressionsUp {
@@ -179,5 +180,7 @@ trait CheckAnalysis {
}
}
extendedCheckRules.foreach(_(plan))
+
+ plan.foreach(_.setAnalyzed())
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 490f3dc07b..970f3c8282 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -144,7 +144,8 @@ object HiveTypeCoercion {
* instances higher in the query tree.
*/
object PropagateTypes extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+
// No propagation required for leaf nodes.
case q: LogicalPlan if q.children.isEmpty => q
@@ -225,7 +226,9 @@ object HiveTypeCoercion {
}
}
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case p if p.analyzed => p
+
case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
val (newLeft, newRight) = widenOutputTypes(u.nodeName, left, right)
Union(newLeft, newRight)
@@ -242,7 +245,7 @@ object HiveTypeCoercion {
* Promotes strings that appear in arithmetic expressions.
*/
object PromoteStrings extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
@@ -305,7 +308,7 @@ object HiveTypeCoercion {
* Convert all expressions in in() list to the left operator type
*/
object InConversion extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
@@ -372,7 +375,8 @@ object HiveTypeCoercion {
ChangeDecimalPrecision(Cast(e, dataType))
}
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+
// fix decimal precision for expressions
case q => q.transformExpressions {
// Skip nodes whose children have not been resolved yet
@@ -466,7 +470,7 @@ object HiveTypeCoercion {
))
}
- def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
@@ -508,7 +512,7 @@ object HiveTypeCoercion {
* truncated version of this number.
*/
object StringToIntegralCasts extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
@@ -521,7 +525,7 @@ object HiveTypeCoercion {
* This ensure that the types for various functions are as expected.
*/
object FunctionArgumentConversion extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
@@ -575,7 +579,7 @@ object HiveTypeCoercion {
* converted to fractional types.
*/
object Division extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
// Skip nodes who has not been resolved yet,
// as this is an extra rule which should be applied at last.
case e if !e.resolved => e
@@ -592,7 +596,7 @@ object HiveTypeCoercion {
* Coerces the type of different branches of a CASE WHEN statement to a common type.
*/
object CaseWhenCoercion extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual =>
logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}")
val maybeCommonType = findTightestCommonTypeAndPromoteToString(c.valueTypes)
@@ -628,7 +632,7 @@ object HiveTypeCoercion {
* Coerces the type of different branches of If statement to a common type.
*/
object IfCoercion extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
// Find tightest common type for If, if the true value and false value have different types.
case i @ If(pred, left, right) if left.dataType != right.dataType =>
findTightestCommonTypeToString(left.dataType, right.dataType).map { widestType =>
@@ -652,7 +656,7 @@ object HiveTypeCoercion {
private val acceptedTypes = Seq(DateType, TimestampType, StringType)
- def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
@@ -669,7 +673,7 @@ object HiveTypeCoercion {
* Casts types according to the expected input types for [[Expression]]s.
*/
object ImplicitTypeCasts extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index c610f70d38..55286f9f2f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.plans
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, VirtualColumn}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types.{DataType, StructType}
@@ -92,7 +93,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
val newArgs = productIterator.map(recursiveTransform).toArray
- if (changed) makeCopy(newArgs) else this
+ if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this
}
/**
@@ -124,7 +125,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
val newArgs = productIterator.map(recursiveTransform).toArray
- if (changed) makeCopy(newArgs) else this
+ if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this
}
/** Returns the result of running [[transformExpressions]] on this node
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index bedeaf06ad..9b52f02009 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -22,11 +22,60 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode}
abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
+ private var _analyzed: Boolean = false
+
+ /**
+ * Marks this plan as already analyzed. This should only be called by CheckAnalysis.
+ */
+ private[catalyst] def setAnalyzed(): Unit = { _analyzed = true }
+
+ /**
+ * Returns true if this node and its children have already been gone through analysis and
+ * verification. Note that this is only an optimization used to avoid analyzing trees that
+ * have already been analyzed, and can be reset by transformations.
+ */
+ def analyzed: Boolean = _analyzed
+
+ /**
+ * Returns a copy of this node where `rule` has been recursively applied first to all of its
+ * children and then itself (post-order). When `rule` does not apply to a given node, it is left
+ * unchanged. This function is similar to `transformUp`, but skips sub-trees that have already
+ * been marked as analyzed.
+ *
+ * @param rule the function use to transform this nodes children
+ */
+ def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
+ if (!analyzed) {
+ val afterRuleOnChildren = transformChildren(rule, (t, r) => t.resolveOperators(r))
+ if (this fastEquals afterRuleOnChildren) {
+ CurrentOrigin.withOrigin(origin) {
+ rule.applyOrElse(this, identity[LogicalPlan])
+ }
+ } else {
+ CurrentOrigin.withOrigin(origin) {
+ rule.applyOrElse(afterRuleOnChildren, identity[LogicalPlan])
+ }
+ }
+ } else {
+ this
+ }
+ }
+
+ /**
+ * Recursively transforms the expressions of a tree, skipping nodes that have already
+ * been analyzed.
+ */
+ def resolveExpressions(r: PartialFunction[Expression, Expression]): LogicalPlan = {
+ this resolveOperators {
+ case p => p.transformExpressions(r)
+ }
+ }
+
/**
* Computes [[Statistics]] for this plan. The default implementation assumes the output
* cardinality is the product of of all child plan's cardinality, i.e. applies in the case
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
index 3f9858b0c4..8b824511a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
@@ -21,6 +21,23 @@ import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.util.sideBySide
+import scala.collection.mutable
+
+object RuleExecutor {
+ protected val timeMap = new mutable.HashMap[String, Long].withDefault(_ => 0)
+
+ /** Resets statistics about time spent running specific rules */
+ def resetTime(): Unit = timeMap.clear()
+
+ /** Dump statistics about time spent running specific rules. */
+ def dumpTimeSpent(): String = {
+ val maxSize = timeMap.keys.map(_.toString.length).max
+ timeMap.toSeq.sortBy(_._2).reverseMap { case (k, v) =>
+ s"${k.padTo(maxSize, " ").mkString} $v"
+ }.mkString("\n")
+ }
+}
+
abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
/**
@@ -41,6 +58,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
/** Defines a sequence of rule batches, to be overridden by the implementation. */
protected val batches: Seq[Batch]
+
/**
* Executes the batches of rules defined by the subclass. The batches are executed serially
* using the defined execution strategy. Within each batch, rules are also executed serially.
@@ -58,7 +76,11 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
while (continue) {
curPlan = batch.rules.foldLeft(curPlan) {
case (plan, rule) =>
+ val startTime = System.nanoTime()
val result = rule(plan)
+ val runTime = System.nanoTime() - startTime
+ RuleExecutor.timeMap(rule.ruleName) = RuleExecutor.timeMap(rule.ruleName) + runTime
+
if (!result.fastEquals(plan)) {
logTrace(
s"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 122e9fc5ed..7971e25188 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -149,7 +149,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
/**
* Returns a copy of this node where `f` has been applied to all the nodes children.
*/
- def mapChildren(f: BaseType => BaseType): this.type = {
+ def mapChildren(f: BaseType => BaseType): BaseType = {
var changed = false
val newArgs = productIterator.map {
case arg: TreeNode[_] if containsChild(arg) =>
@@ -170,7 +170,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
* Returns a copy of this node with the children replaced.
* TODO: Validate somewhere (in debug mode?) that children are ordered correctly.
*/
- def withNewChildren(newChildren: Seq[BaseType]): this.type = {
+ def withNewChildren(newChildren: Seq[BaseType]): BaseType = {
assert(newChildren.size == children.size, "Incorrect number of children")
var changed = false
val remainingNewChildren = newChildren.toBuffer
@@ -229,9 +229,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
// Check if unchanged and then possibly return old copy to avoid gc churn.
if (this fastEquals afterRule) {
- transformChildrenDown(rule)
+ transformChildren(rule, (t, r) => t.transformDown(r))
} else {
- afterRule.transformChildrenDown(rule)
+ afterRule.transformChildren(rule, (t, r) => t.transformDown(r))
}
}
@@ -240,11 +240,13 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
* this node. When `rule` does not apply to a given node it is left unchanged.
* @param rule the function used to transform this nodes children
*/
- def transformChildrenDown(rule: PartialFunction[BaseType, BaseType]): this.type = {
+ protected def transformChildren(
+ rule: PartialFunction[BaseType, BaseType],
+ nextOperation: (BaseType, PartialFunction[BaseType, BaseType]) => BaseType): BaseType = {
var changed = false
val newArgs = productIterator.map {
case arg: TreeNode[_] if containsChild(arg) =>
- val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
+ val newChild = nextOperation(arg.asInstanceOf[BaseType], rule)
if (!(newChild fastEquals arg)) {
changed = true
newChild
@@ -252,7 +254,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
arg
}
case Some(arg: TreeNode[_]) if containsChild(arg) =>
- val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
+ val newChild = nextOperation(arg.asInstanceOf[BaseType], rule)
if (!(newChild fastEquals arg)) {
changed = true
Some(newChild)
@@ -263,7 +265,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
case d: DataType => d // Avoid unpacking Structs
case args: Traversable[_] => args.map {
case arg: TreeNode[_] if containsChild(arg) =>
- val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
+ val newChild = nextOperation(arg.asInstanceOf[BaseType], rule)
if (!(newChild fastEquals arg)) {
changed = true
newChild
@@ -285,7 +287,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
* @param rule the function use to transform this nodes children
*/
def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {
- val afterRuleOnChildren = transformChildrenUp(rule)
+ val afterRuleOnChildren = transformChildren(rule, (t, r) => t.transformUp(r))
if (this fastEquals afterRuleOnChildren) {
CurrentOrigin.withOrigin(origin) {
rule.applyOrElse(this, identity[BaseType])
@@ -297,44 +299,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
}
}
- def transformChildrenUp(rule: PartialFunction[BaseType, BaseType]): this.type = {
- var changed = false
- val newArgs = productIterator.map {
- case arg: TreeNode[_] if containsChild(arg) =>
- val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
- if (!(newChild fastEquals arg)) {
- changed = true
- newChild
- } else {
- arg
- }
- case Some(arg: TreeNode[_]) if containsChild(arg) =>
- val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
- if (!(newChild fastEquals arg)) {
- changed = true
- Some(newChild)
- } else {
- Some(arg)
- }
- case m: Map[_, _] => m
- case d: DataType => d // Avoid unpacking Structs
- case args: Traversable[_] => args.map {
- case arg: TreeNode[_] if containsChild(arg) =>
- val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
- if (!(newChild fastEquals arg)) {
- changed = true
- newChild
- } else {
- arg
- }
- case other => other
- }
- case nonChild: AnyRef => nonChild
- case null => null
- }.toArray
- if (changed) makeCopy(newArgs) else this
- }
-
/**
* Args to the constructor that should be copied, but not transformed.
* These are appended to the transformed args automatically by makeCopy
@@ -348,7 +312,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
* that are not present in the productIterator.
* @param newArgs the new product arguments.
*/
- def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") {
+ def makeCopy(newArgs: Array[AnyRef]): BaseType = attachTree(this, "makeCopy") {
val ctors = getClass.getConstructors.filter(_.getParameterTypes.size != 0)
if (ctors.isEmpty) {
sys.error(s"No valid constructor for $nodeName")
@@ -359,9 +323,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
CurrentOrigin.withOrigin(origin) {
// Skip no-arg constructors that are just there for kryo.
if (otherCopyArgs.isEmpty) {
- defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type]
+ defaultCtor.newInstance(newArgs: _*).asInstanceOf[BaseType]
} else {
- defaultCtor.newInstance((newArgs ++ otherCopyArgs).toArray: _*).asInstanceOf[this.type]
+ defaultCtor.newInstance((newArgs ++ otherCopyArgs).toArray: _*).asInstanceOf[BaseType]
}
}
} catch {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
new file mode 100644
index 0000000000..797b29f23c
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.util._
+
+/**
+ * Provides helper methods for comparing plans.
+ */
+class LogicalPlanSuite extends SparkFunSuite {
+ private var invocationCount = 0
+ private val function: PartialFunction[LogicalPlan, LogicalPlan] = {
+ case p: Project =>
+ invocationCount += 1
+ p
+ }
+
+ private val testRelation = LocalRelation()
+
+ test("resolveOperator runs on operators") {
+ invocationCount = 0
+ val plan = Project(Nil, testRelation)
+ plan resolveOperators function
+
+ assert(invocationCount === 1)
+ }
+
+ test("resolveOperator runs on operators recursively") {
+ invocationCount = 0
+ val plan = Project(Nil, Project(Nil, testRelation))
+ plan resolveOperators function
+
+ assert(invocationCount === 2)
+ }
+
+ test("resolveOperator skips all ready resolved plans") {
+ invocationCount = 0
+ val plan = Project(Nil, Project(Nil, testRelation))
+ plan.foreach(_.setAnalyzed())
+ plan resolveOperators function
+
+ assert(invocationCount === 0)
+ }
+
+ test("resolveOperator skips partially resolved plans") {
+ invocationCount = 0
+ val plan1 = Project(Nil, testRelation)
+ val plan2 = Project(Nil, plan1)
+ plan1.foreach(_.setAnalyzed())
+ plan2 resolveOperators function
+
+ assert(invocationCount === 1)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index db15711202..e57acec59d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
import java.io.CharArrayWriter
import java.util.Properties
+import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.unsafe.types.UTF8String
import scala.language.implicitConversions
@@ -54,7 +55,6 @@ private[sql] object DataFrame {
}
}
-
/**
* :: Experimental ::
* A distributed collection of data organized into named columns.
@@ -690,9 +690,7 @@ class DataFrame private[sql](
case Column(explode: Explode) => MultiAlias(explode, Nil)
case Column(expr: Expression) => Alias(expr, expr.prettyString)()
}
- // When user continuously call `select`, speed up analysis by collapsing `Project`
- import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing
- Project(namedExpressions.toSeq, ProjectCollapsing(logicalPlan))
+ Project(namedExpressions.toSeq, logicalPlan)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 73b237fffe..dbc0cefbe2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -67,7 +67,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
private val prepareCalled = new AtomicBoolean(false)
/** Overridden make copy also propogates sqlContext to copied plan. */
- override def makeCopy(newArgs: Array[AnyRef]): this.type = {
+ override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = {
SparkPlan.currentContext.set(sqlContext)
super.makeCopy(newArgs)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
index dedc7c4dfb..59f8b079ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
@@ -65,7 +65,7 @@ private[spark] case class PythonUDF(
* multiple child operators.
*/
private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
// Skip EvaluatePython nodes.
case plan: EvaluatePython => plan
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index e9dd7ef226..a88df91b10 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -25,6 +25,7 @@ import org.scalatest.concurrent.Eventually._
import org.apache.spark.Accumulators
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.columnar._
+import org.apache.spark.sql.functions._
import org.apache.spark.storage.{StorageLevel, RDDBlockId}
case class BigData(s: String)
@@ -50,6 +51,25 @@ class CachedTableSuite extends QueryTest {
ctx.sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty
}
+ test("withColumn doesn't invalidate cached dataframe") {
+ var evalCount = 0
+ val myUDF = udf((x: String) => { evalCount += 1; "result" })
+ val df = Seq(("test", 1)).toDF("s", "i").select(myUDF($"s"))
+ df.cache()
+
+ df.collect()
+ assert(evalCount === 1)
+
+ df.collect()
+ assert(evalCount === 1)
+
+ val df2 = df.withColumn("newColumn", lit(1))
+ df2.collect()
+
+ // We should not reevaluate the cached dataframe
+ assert(evalCount === 1)
+ }
+
test("cache temp table") {
testData.select('key).registerTempTable("tempTable")
assertCached(sql("SELECT COUNT(*) FROM tempTable"), 0)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index b8f10b00f5..f9cc6d1f3c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -686,18 +686,6 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
Seq(Row(2, 1, 2), Row(1, 1, 1)))
}
- test("SPARK-7276: Project collapse for continuous select") {
- var df = testData
- for (i <- 1 to 5) {
- df = df.select($"*")
- }
-
- import org.apache.spark.sql.catalyst.plans.logical.Project
- // make sure df have at most two Projects
- val p = df.logicalPlan.asInstanceOf[Project].child.asInstanceOf[Project]
- assert(!p.child.isInstanceOf[Project])
- }
-
test("SPARK-7150 range api") {
// numSlice is greater than length
val res1 = sqlContext.range(0, 10, 1, 15).select("id")
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index d4fc6c2b6e..ab309e0a1d 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.execution
import java.io.File
import java.util.{Locale, TimeZone}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.SQLConf
@@ -50,6 +51,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, 5)
// Enable in-memory partition pruning for testing purposes
TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true)
+ RuleExecutor.resetTime()
}
override def afterAll() {
@@ -58,6 +60,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
Locale.setDefault(originalLocale)
TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
+
+ // For debugging dump some statistics about how much time was spent in various optimizer rules.
+ logWarning(RuleExecutor.dumpTimeSpent())
}
/** A list of tests deemed out of scope currently and thus completely disregarded. */
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 16c186627f..6b37af99f4 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -391,7 +391,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
*/
object ParquetConversions extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = {
- if (!plan.resolved) {
+ if (!plan.resolved || plan.analyzed) {
return plan
}
@@ -418,8 +418,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
(relation, parquetRelation, attributedRewrites)
// Read path
- case p @ PhysicalOperation(_, _, relation: MetastoreRelation)
- if hive.convertMetastoreParquet &&
+ case relation: MetastoreRelation if hive.convertMetastoreParquet &&
relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") =>
val parquetRelation = convertToParquetRelation(relation)
val attributedRewrites = relation.output.zip(parquetRelation.output)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 8a86a87368..7182246e46 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -133,8 +133,7 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre
@transient
private lazy val conversionHelper = new ConversionHelper(method, arguments)
- @transient
- lazy val dataType = javaClassToDataType(method.getReturnType)
+ val dataType = javaClassToDataType(method.getReturnType)
@transient
lazy val returnInspector = ObjectInspectorFactory.getReflectionObjectInspector(
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index 7069afc9f7..10f2902e5e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -182,7 +182,7 @@ class HiveUDFSuite extends QueryTest {
val errMsg = intercept[AnalysisException] {
sql("SELECT testUDFToListString(s) FROM inputTable")
}
- assert(errMsg.getMessage === "List type in java is unsupported because " +
+ assert(errMsg.getMessage contains "List type in java is unsupported because " +
"JVM type erasure makes spark fail to catch a component type in List<>;")
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListString")
@@ -197,7 +197,7 @@ class HiveUDFSuite extends QueryTest {
val errMsg = intercept[AnalysisException] {
sql("SELECT testUDFToListInt(s) FROM inputTable")
}
- assert(errMsg.getMessage === "List type in java is unsupported because " +
+ assert(errMsg.getMessage contains "List type in java is unsupported because " +
"JVM type erasure makes spark fail to catch a component type in List<>;")
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListInt")
@@ -213,7 +213,7 @@ class HiveUDFSuite extends QueryTest {
val errMsg = intercept[AnalysisException] {
sql("SELECT testUDFToStringIntMap(s) FROM inputTable")
}
- assert(errMsg.getMessage === "Map type in java is unsupported because " +
+ assert(errMsg.getMessage contains "Map type in java is unsupported because " +
"JVM type erasure makes spark fail to catch key and value types in Map<>;")
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToStringIntMap")
@@ -229,7 +229,7 @@ class HiveUDFSuite extends QueryTest {
val errMsg = intercept[AnalysisException] {
sql("SELECT testUDFToIntIntMap(s) FROM inputTable")
}
- assert(errMsg.getMessage === "Map type in java is unsupported because " +
+ assert(errMsg.getMessage contains "Map type in java is unsupported because " +
"JVM type erasure makes spark fail to catch key and value types in Map<>;")
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToIntIntMap")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index ff9a3694d6..1dff07a6de 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -948,6 +948,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils {
}
test("SPARK-7595: Window will cause resolve failed with self join") {
+ sql("SELECT * FROM src") // Force loading of src table.
+
checkAnswer(sql(
"""
|with