diff options
author | gatorsmile <gatorsmile@gmail.com> | 2016-03-24 11:13:36 +0800 |
---|---|---|
committer | Wenchen Fan <wenchen@databricks.com> | 2016-03-24 11:13:36 +0800 |
commit | f42eaf42bdca8bc6f390f1f31ee60faa1662489b (patch) | |
tree | d5f84d22eb06d95fd4b20a79a6ccaaa34c911d89 /sql/catalyst | |
parent | de4e48b62b998d45d4a749234741a45534719497 (diff) | |
download | spark-f42eaf42bdca8bc6f390f1f31ee60faa1662489b.tar.gz spark-f42eaf42bdca8bc6f390f1f31ee60faa1662489b.tar.bz2 spark-f42eaf42bdca8bc6f390f1f31ee60faa1662489b.zip |
[SPARK-14085][SQL] Star Expansion for Hash
#### What changes were proposed in this pull request?
This PR is to support star expansion in hash. For example,
```SQL
val structDf = testData2.select("a", "b").as("record")
structDf.select(hash($"*")
```
In addition, it refactors the codes for the rule `ResolveStar` and fixes a regression for star expansion in group by when using SQL API. For example,
```SQL
SELECT * FROM testData2 group by a, b
```
cc cloud-fan Now, the code for star resolution is much cleaner. The coverage is better. Could you check if this refactoring is good? Thanks!
#### How was this patch tested?
Added a few test cases to cover it.
Author: gatorsmile <gatorsmile@gmail.com>
Closes #11904 from gatorsmile/starResolution.
Diffstat (limited to 'sql/catalyst')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 40 |
1 files changed, 23 insertions, 17 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 178e9402fa..54543eebb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -380,27 +380,12 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p: LogicalPlan if !p.childrenResolved => p - // If the projection list contains Stars, expand it. case p: Project if containsStar(p.projectList) => - val expanded = p.projectList.flatMap { - case s: Star => s.expand(p.child, resolver) - case ua @ UnresolvedAlias(_: UnresolvedFunction | _: CreateArray | _: CreateStruct, _) => - UnresolvedAlias(child = expandStarExpression(ua.child, p.child)) :: Nil - case a @ Alias(_: UnresolvedFunction | _: CreateArray | _: CreateStruct, _) => - a.withNewChildren(expandStarExpression(a.child, p.child) :: Nil) - .asInstanceOf[Alias] :: Nil - case o => o :: Nil - } - Project(projectList = expanded, p.child) + p.copy(projectList = buildExpandedProjectList(p.projectList, p.child)) // If the aggregate function argument contains Stars, expand it. case a: Aggregate if containsStar(a.aggregateExpressions) => - val expanded = a.aggregateExpressions.flatMap { - case s: Star => s.expand(a.child, resolver) - case o if containsStar(o :: Nil) => expandStarExpression(o, a.child) :: Nil - case o => o :: Nil - }.map(_.asInstanceOf[NamedExpression]) - a.copy(aggregateExpressions = expanded) + a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child)) // If the script transformation input contains Stars, expand it. case t: ScriptTransformation if containsStar(t.input) => t.copy( @@ -414,6 +399,22 @@ class Analyzer( } /** + * Build a project list for Project/Aggregate and expand the star if possible + */ + private def buildExpandedProjectList( + exprs: Seq[NamedExpression], + child: LogicalPlan): Seq[NamedExpression] = { + exprs.flatMap { + // Using Dataframe/Dataset API: testData2.groupBy($"a", $"b").agg($"*") + case s: Star => s.expand(child, resolver) + // Using SQL API without running ResolveAlias: SELECT * FROM testData2 group by a, b + case UnresolvedAlias(s: Star, _) => s.expand(child, resolver) + case o if containsStar(o :: Nil) => expandStarExpression(o, child) :: Nil + case o => o :: Nil + }.map(_.asInstanceOf[NamedExpression]) + } + + /** * Returns true if `exprs` contains a [[Star]]. */ def containsStar(exprs: Seq[Expression]): Boolean = @@ -439,6 +440,11 @@ class Analyzer( case s: Star => s.expand(child, resolver) case o => o :: Nil }) + case p: Murmur3Hash if containsStar(p.children) => + p.copy(children = p.children.flatMap { + case s: Star => s.expand(child, resolver) + case o => o :: Nil + }) // count(*) has been replaced by count(1) case o if containsStar(o.children) => failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'") |