-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala                      | 162
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala                    |  28
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala                     |   9
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala                 |   2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala           |  17
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala                |   7
-rw-r--r--  sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala |   4
7 files changed, 124 insertions(+), 105 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 89b18af9a0..3b83e68018 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -80,7 +80,6 @@ class Analyzer(
EliminateUnions),
Batch("Resolution", fixedPoint,
ResolveRelations ::
- ResolveStar ::
ResolveReferences ::
ResolveGroupingAnalytics ::
ResolvePivot ::
@@ -375,91 +374,6 @@ class Analyzer(
}
/**
- * Expand [[UnresolvedStar]] or [[ResolvedStar]] to the matching attributes in child's output.
- */
- object ResolveStar extends Rule[LogicalPlan] {
-
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
- case p: LogicalPlan if !p.childrenResolved => p
- // If the projection list contains Stars, expand it.
- case p: Project if containsStar(p.projectList) =>
- p.copy(projectList = buildExpandedProjectList(p.projectList, p.child))
- // If the aggregate function argument contains Stars, expand it.
- case a: Aggregate if containsStar(a.aggregateExpressions) =>
- if (conf.groupByOrdinal && a.groupingExpressions.exists(IntegerIndex.unapply(_).nonEmpty)) {
- failAnalysis(
- "Group by position: star is not allowed to use in the select list " +
- "when using ordinals in group by")
- } else {
- a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child))
- }
- // If the script transformation input contains Stars, expand it.
- case t: ScriptTransformation if containsStar(t.input) =>
- t.copy(
- input = t.input.flatMap {
- case s: Star => s.expand(t.child, resolver)
- case o => o :: Nil
- }
- )
- case g: Generate if containsStar(g.generator.children) =>
- failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF")
- }
-
- /**
- * Build a project list for Project/Aggregate and expand the star if possible
- */
- private def buildExpandedProjectList(
- exprs: Seq[NamedExpression],
- child: LogicalPlan): Seq[NamedExpression] = {
- exprs.flatMap {
- // Using Dataframe/Dataset API: testData2.groupBy($"a", $"b").agg($"*")
- case s: Star => s.expand(child, resolver)
- // Using SQL API without running ResolveAlias: SELECT * FROM testData2 group by a, b
- case UnresolvedAlias(s: Star, _) => s.expand(child, resolver)
- case o if containsStar(o :: Nil) => expandStarExpression(o, child) :: Nil
- case o => o :: Nil
- }.map(_.asInstanceOf[NamedExpression])
- }
-
- /**
- * Returns true if `exprs` contains a [[Star]].
- */
- def containsStar(exprs: Seq[Expression]): Boolean =
- exprs.exists(_.collect { case _: Star => true }.nonEmpty)
-
- /**
- * Expands the matching attribute.*'s in `child`'s output.
- */
- def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = {
- expr.transformUp {
- case f1: UnresolvedFunction if containsStar(f1.children) =>
- f1.copy(children = f1.children.flatMap {
- case s: Star => s.expand(child, resolver)
- case o => o :: Nil
- })
- case c: CreateStruct if containsStar(c.children) =>
- c.copy(children = c.children.flatMap {
- case s: Star => s.expand(child, resolver)
- case o => o :: Nil
- })
- case c: CreateArray if containsStar(c.children) =>
- c.copy(children = c.children.flatMap {
- case s: Star => s.expand(child, resolver)
- case o => o :: Nil
- })
- case p: Murmur3Hash if containsStar(p.children) =>
- p.copy(children = p.children.flatMap {
- case s: Star => s.expand(child, resolver)
- case o => o :: Nil
- })
- // count(*) has been replaced by count(1)
- case o if containsStar(o.children) =>
- failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'")
- }
- }
- }
-
- /**
* Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from
* a logical plan node's children.
*/
@@ -525,6 +439,29 @@ class Analyzer(
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p: LogicalPlan if !p.childrenResolved => p
+ // If the projection list contains Stars, expand it.
+ case p: Project if containsStar(p.projectList) =>
+ p.copy(projectList = buildExpandedProjectList(p.projectList, p.child))
+ // If the aggregate function argument contains Stars, expand it.
+ case a: Aggregate if containsStar(a.aggregateExpressions) =>
+ if (conf.groupByOrdinal && a.groupingExpressions.exists(IntegerIndex.unapply(_).nonEmpty)) {
+ failAnalysis(
+ "Group by position: star is not allowed to use in the select list " +
+ "when using ordinals in group by")
+ } else {
+ a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child))
+ }
+ // If the script transformation input contains Stars, expand it.
+ case t: ScriptTransformation if containsStar(t.input) =>
+ t.copy(
+ input = t.input.flatMap {
+ case s: Star => s.expand(t.child, resolver)
+ case o => o :: Nil
+ }
+ )
+ case g: Generate if containsStar(g.generator.children) =>
+ failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF")
+
// To resolve duplicate expression IDs for Join and Intersect
case j @ Join(left, right, _, _) if !j.duplicateResolved =>
j.copy(right = dedupRight(left, right))
@@ -619,6 +556,59 @@ class Analyzer(
def findAliases(projectList: Seq[NamedExpression]): AttributeSet = {
AttributeSet(projectList.collect { case a: Alias => a.toAttribute })
}
+
+ /**
+ * Build a project list for Project/Aggregate and expand the star if possible
+ */
+ private def buildExpandedProjectList(
+ exprs: Seq[NamedExpression],
+ child: LogicalPlan): Seq[NamedExpression] = {
+ exprs.flatMap {
+ // Using Dataframe/Dataset API: testData2.groupBy($"a", $"b").agg($"*")
+ case s: Star => s.expand(child, resolver)
+ // Using SQL API without running ResolveAlias: SELECT * FROM testData2 group by a, b
+ case UnresolvedAlias(s: Star, _) => s.expand(child, resolver)
+ case o if containsStar(o :: Nil) => expandStarExpression(o, child) :: Nil
+ case o => o :: Nil
+ }.map(_.asInstanceOf[NamedExpression])
+ }
+
+ /**
+ * Returns true if `exprs` contains a [[Star]].
+ */
+ def containsStar(exprs: Seq[Expression]): Boolean =
+ exprs.exists(_.collect { case _: Star => true }.nonEmpty)
+
+ /**
+ * Expands the matching attribute.*'s in `child`'s output.
+ */
+ def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = {
+ expr.transformUp {
+ case f1: UnresolvedFunction if containsStar(f1.children) =>
+ f1.copy(children = f1.children.flatMap {
+ case s: Star => s.expand(child, resolver)
+ case o => o :: Nil
+ })
+ case c: CreateStruct if containsStar(c.children) =>
+ c.copy(children = c.children.flatMap {
+ case s: Star => s.expand(child, resolver)
+ case o => o :: Nil
+ })
+ case c: CreateArray if containsStar(c.children) =>
+ c.copy(children = c.children.flatMap {
+ case s: Star => s.expand(child, resolver)
+ case o => o :: Nil
+ })
+ case p: Murmur3Hash if containsStar(p.children) =>
+ p.copy(children = p.children.flatMap {
+ case s: Star => s.expand(child, resolver)
+ case o => o :: Nil
+ })
+ // count(*) has been replaced by count(1)
+ case o if containsStar(o.children) =>
+ failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'")
+ }
+ }
}
protected[sql] def resolveExpression(
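The hunk above folds the old ResolveStar rule into ResolveReferences, so star expansion now runs in the same fixed-point pass that resolves ordinary attributes. As a minimal illustration of the expansion semantics, the following self-contained Scala sketch uses hypothetical Expr/Attr/Func/Child stand-ins for Catalyst's Expression, Attribute, and LogicalPlan types; it is a simplified model, not the real API:

    // A toy model of star expansion, not Catalyst code. A bare Star in a
    // projection list becomes all of the child's output attributes; a Star
    // nested inside a non-whitelisted function is rejected, mirroring the
    // "Invalid usage of '*'" analysis error above.
    object StarExpansionSketch {
      sealed trait Expr
      case object Star extends Expr
      case class Attr(name: String) extends Expr
      case class Func(name: String, children: Seq[Expr]) extends Expr

      case class Child(output: Seq[Attr])

      def expand(projectList: Seq[Expr], child: Child): Seq[Expr] =
        projectList.flatMap {
          case Star    => child.output
          case f: Func => expandInExpr(f, child) :: Nil
          case other   => other :: Nil
        }

      // The whitelist stands in for the real rule's UnresolvedFunction /
      // CreateStruct / CreateArray / Murmur3Hash cases, reduced to names.
      private def expandInExpr(f: Func, child: Child): Expr =
        if (!f.children.contains(Star)) f
        else if (Set("struct", "array", "hash")(f.name))
          f.copy(children = f.children.flatMap {
            case Star => child.output
            case o    => o :: Nil
          })
        else sys.error(s"Invalid usage of '*' in expression '${f.name}'")

      def main(args: Array[String]): Unit = {
        val child = Child(Seq(Attr("a"), Attr("b")))
        println(expand(Seq(Star), child))                      // List(Attr(a), Attr(b))
        println(expand(Seq(Func("struct", Seq(Star))), child)) // struct(*) expands to struct(a, b)
      }
    }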
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 4cfdcf95cb..a7a948ef1b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -306,21 +306,21 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
}
/**
- * Attempts to eliminate the reading of unneeded columns from the query plan using the following
- * transformations:
+ * Attempts to eliminate the reading of unneeded columns from the query plan.
*
- * - Inserting Projections beneath the following operators:
- * - Aggregate
- * - Generate
- * - Project <- Join
- * - LeftSemiJoin
+ * Since adding a Project before a Filter conflicts with PushPredicateThroughProject, this rule
+ * will remove the Project p2 in the following pattern:
+ *
+ *   p1 @ Project(_, Filter(_, p2 @ Project(_, child))) if p2.outputSet.subsetOf(p2.inputSet)
+ *
+ * p2 is usually inserted by this rule and is useless; p1 can prune the columns anyway.
*/
object ColumnPruning extends Rule[LogicalPlan] {
private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean =
output1.size == output2.size &&
output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2))
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ def apply(plan: LogicalPlan): LogicalPlan = removeProjectBeforeFilter(plan transform {
// Prunes the unused columns from project list of Project/Aggregate/Expand
case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty =>
p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains)))
@@ -399,7 +399,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
} else {
p
}
- }
+ })
/** Applies a projection only when the child is producing unnecessary attributes */
private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
@@ -408,6 +408,16 @@ object ColumnPruning extends Rule[LogicalPlan] {
} else {
c
}
+
+ /**
+ * The Project before Filter is not necessary but conflicts with PushPredicateThroughProject,
+ * so remove it.
+ */
+ private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transform {
+ case p1 @ Project(_, f @ Filter(_, p2 @ Project(_, child)))
+ if p2.outputSet.subsetOf(child.outputSet) =>
+ p1.copy(child = f.copy(child = child))
+ }
}
/**
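To make the new removeProjectBeforeFilter rewrite concrete, here is a minimal sketch over a toy plan ADT; Rel, Project, and Filter are hypothetical stand-ins for Catalyst's operators, with attribute sets reduced to plain name sets. The inner Project is dropped whenever it keeps only columns the child already produces, leaving the outer Project to do the pruning:

    object RemoveProjectBeforeFilterSketch {
      sealed trait Plan { def output: Set[String] }
      case class Rel(output: Set[String]) extends Plan
      case class Project(projectList: Set[String], child: Plan) extends Plan {
        def output: Set[String] = projectList
      }
      case class Filter(cond: String, child: Plan) extends Plan {
        def output: Set[String] = child.output
      }

      // Mirrors the new rewrite: in Project(Filter(Project(child))), the inner
      // project is redundant when it only keeps columns the child already
      // produces, so it is removed; the outer project still prunes columns.
      def removeProjectBeforeFilter(plan: Plan): Plan = plan match {
        case p1 @ Project(_, f @ Filter(_, p2 @ Project(_, child)))
            if p2.output.subsetOf(child.output) =>
          p1.copy(child = f.copy(child = child))
        case other => other
      }

      def main(args: Array[String]): Unit = {
        val input = Rel(Set("a", "b", "c"))
        val plan = Project(Set("a"), Filter("c > 0", Project(Set("a", "c"), input)))
        // The inner Project(Set(a, c), ...) is gone from the printed plan.
        println(removeProjectBeforeFilter(plan))
      }
    }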
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
index 8e30349f50..6fc828f63f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
@@ -22,8 +22,10 @@ import scala.collection.JavaConverters._
import com.google.common.util.concurrent.AtomicLongMap
import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.util.sideBySide
+import org.apache.spark.util.Utils
object RuleExecutor {
protected val timeMap = AtomicLongMap.create[String]()
@@ -98,7 +100,12 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
if (iteration > batch.strategy.maxIterations) {
// Only log if this is a rule that is supposed to run more than once.
if (iteration != 2) {
- logInfo(s"Max iterations (${iteration - 1}) reached for batch ${batch.name}")
+ val message = s"Max iterations (${iteration - 1}) reached for batch ${batch.name}"
+ if (Utils.isTesting) {
+ throw new TreeNodeException(curPlan, message, null)
+ } else {
+ logWarning(message)
+ }
}
continue = false
}
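The change above turns a non-converging batch from a log message into a hard failure when running under tests. The following simplified sketch shows the control flow; isTesting and MaxIterationsException are stand-ins for Utils.isTesting and TreeNodeException, and the iteration bookkeeping is condensed:

    object FixedPointSketch {
      type Rule[T] = T => T

      // Stand-in for Utils.isTesting, which checks the spark.testing property.
      def isTesting: Boolean = sys.props.contains("spark.testing")

      // Stand-in for TreeNodeException; carries only the message here.
      final class MaxIterationsException(msg: String) extends RuntimeException(msg)

      def executeBatch[T](name: String, maxIterations: Int, rules: Seq[Rule[T]], plan: T): T = {
        var curPlan = plan
        var iteration = 1
        var continue = true
        while (continue) {
          val newPlan = rules.foldLeft(curPlan)((p, rule) => rule(p))
          if (newPlan == curPlan) {
            continue = false // fixed point reached
          } else if (iteration >= maxIterations) {
            val message = s"Max iterations ($maxIterations) reached for batch $name"
            if (isTesting) throw new MaxIterationsException(message)
            Console.err.println(s"WARN: $message")
            continue = false
          } else {
            curPlan = newPlan
            iteration += 1
          }
        }
        curPlan
      }

      def main(args: Array[String]): Unit = {
        sys.props("spark.testing") = "true"
        val decrement: Rule[Int] = n => if (n > 0) n - 1 else n
        try executeBatch("fixedPoint", 10, Seq(decrement), 100)
        catch { case e: MaxIterationsException => println(e.getMessage) }
      }
    }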
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 346e052437..a63d1770f3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -29,7 +29,7 @@ class AnalysisSuite extends AnalysisTest {
import org.apache.spark.sql.catalyst.analysis.TestRelations._
test("union project *") {
- val plan = (1 to 100)
+ val plan = (1 to 120)
.map(_ => testRelation)
.fold[LogicalPlan](testRelation) { (a, b) =>
a.select(UnresolvedStar(None)).select('a).union(b.select(UnresolvedStar(None)))
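Each fold step wraps the accumulated plan one level deeper, so the range length sets the nesting depth the analyzer must resolve. A tiny sketch of the same folding pattern, on strings instead of plans:

    object FoldDepthSketch {
      def main(args: Array[String]): Unit = {
        // Each fold step wraps the accumulator, so (1 to n) yields n extra levels.
        val nested = (1 to 3).map(_ => "R").fold("R")((a, b) => s"Union(Select($a), $b)")
        println(nested) // Union(Select(Union(Select(Union(Select(R), R)), R)), R)
      }
    }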
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index dd7d65ddc9..2248e03b2f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -34,6 +34,7 @@ class ColumnPruningSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("Column pruning", FixedPoint(100),
+ PushPredicateThroughProject,
ColumnPruning,
CollapseProject) :: Nil
}
@@ -133,12 +134,16 @@ class ColumnPruningSuite extends PlanTest {
test("Column pruning on Filter") {
val input = LocalRelation('a.int, 'b.string, 'c.double)
+ val plan1 = Filter('a > 1, input).analyze
+ comparePlans(Optimize.execute(plan1), plan1)
val query = Project('a :: Nil, Filter('c > Literal(0.0), input)).analyze
- val expected =
- Project('a :: Nil,
- Filter('c > Literal(0.0),
- Project(Seq('a, 'c), input))).analyze
- comparePlans(Optimize.execute(query), expected)
+ comparePlans(Optimize.execute(query), query)
+ val plan2 = Filter('b > 1, Project(Seq('a, 'b), input)).analyze
+ val expected2 = Project(Seq('a, 'b), Filter('b > 1, input)).analyze
+ comparePlans(Optimize.execute(plan2), expected2)
+ val plan3 = Project(Seq('a), Filter('b > 1, Project(Seq('a, 'b), input))).analyze
+ val expected3 = Project(Seq('a), Filter('b > 1, input)).analyze
+ comparePlans(Optimize.execute(plan3), expected3)
}
test("Column pruning on except/intersect/distinct") {
@@ -297,7 +302,7 @@ class ColumnPruningSuite extends PlanTest {
SortOrder('b, Ascending) :: Nil,
UnspecifiedFrame)).as('window) :: Nil,
'a :: Nil, 'b.asc :: Nil)
- .select('a, 'c, 'window).where('window > 1).select('a, 'c).analyze
+ .where('window > 1).select('a, 'c).analyze
val optimized = Optimize.execute(originalQuery.analyze)
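The updated expectations reflect the two rules cooperating in one batch: ColumnPruning no longer inserts a Project beneath a Filter, while PushPredicateThroughProject moves filters below existing projects when the condition allows. A sketch of the pushdown half over a toy ADT (hypothetical; conditions are reduced to the set of column names they reference, and alias substitution is ignored):

    object PushPredicateSketch {
      sealed trait Plan
      case class Rel(output: Set[String]) extends Plan
      case class Project(projectList: Set[String], child: Plan) extends Plan
      case class Filter(references: Set[String], child: Plan) extends Plan

      // Swap Filter over Project into Project over Filter when the condition
      // only needs columns the project already passes through. The real rule
      // must additionally substitute aliases into the condition; the toy list
      // has no aliases, so a subset check suffices.
      def pushPredicateThroughProject(plan: Plan): Plan = plan match {
        case Filter(refs, Project(list, child)) if refs.subsetOf(list) =>
          Project(list, Filter(refs, child))
        case other => other
      }

      def main(args: Array[String]): Unit = {
        val input = Rel(Set("a", "b", "c"))
        // Mirrors plan2 -> expected2 above: Filter('b > 1, Project(Seq('a, 'b), input))
        // becomes Project(Seq('a, 'b), Filter('b > 1, input)).
        println(pushPredicateThroughProject(Filter(Set("b"), Project(Set("a", "b"), input))))
      }
    }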
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala
index a7de7b052b..c9d36910b0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala
@@ -18,7 +18,9 @@
package org.apache.spark.sql.catalyst.trees
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
class RuleExecutorSuite extends SparkFunSuite {
@@ -49,6 +51,9 @@ class RuleExecutorSuite extends SparkFunSuite {
val batches = Batch("fixedPoint", FixedPoint(10), DecrementLiterals) :: Nil
}
- assert(ToFixedPoint.execute(Literal(100)) === Literal(90))
+ val message = intercept[TreeNodeException[LogicalPlan]] {
+ ToFixedPoint.execute(Literal(100))
+ }.getMessage
+ assert(message.contains("Max iterations (10) reached for batch fixedPoint"))
}
}
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 650797f768..8bd731dda2 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -341,6 +341,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udf_round_3",
"view_cast",
+ // enable this after fixing SPARK-14137
+ "union20",
+
// These tests check the VIEW table definition, but Spark handles CREATE VIEW itself and
// generates different View Expanded Text.
"alter_view_as_select",
@@ -1043,7 +1046,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"union18",
"union19",
"union2",
- "union20",
"union22",
"union23",
"union24",