diff options
author | gatorsmile <gatorsmile@gmail.com> | 2016-03-22 08:21:02 +0800 |
---|---|---|
committer | Wenchen Fan <wenchen@databricks.com> | 2016-03-22 08:21:02 +0800 |
commit | 3f49e0766f3a369a44e14632de68c657773b7a27 (patch) | |
tree | 98a493a351dc476656eff031d4f97109ffeed0e0 /sql/catalyst | |
parent | b5f1ab701a167a728bb006e01b392b203da84391 (diff) | |
download | spark-3f49e0766f3a369a44e14632de68c657773b7a27.tar.gz spark-3f49e0766f3a369a44e14632de68c657773b7a27.tar.bz2 spark-3f49e0766f3a369a44e14632de68c657773b7a27.zip |
[SPARK-13320][SQL] Support Star in CreateStruct/CreateArray and Error Handling when DataFrame/DataSet Functions using Star
This PR resolves two issues:
First, expanding * inside aggregate functions of structs when using Dataframe/Dataset APIs. For example,
```scala
structDf.groupBy($"a").agg(min(struct($"record.*")))
```
Second, it improves the error messages when having invalid star usage when using Dataframe/Dataset APIs. For example,
```scala
pagecounts4PartitionsDS
.map(line => (line._1, line._3))
.toDF()
.groupBy($"_1")
.agg(sum("*") as "sumOccurances")
```
Before the fix, the invalid usage will issue a confusing error message, like:
```
org.apache.spark.sql.AnalysisException: cannot resolve '_1' given input columns _1, _2;
```
After the fix, the message is like:
```
org.apache.spark.sql.AnalysisException: Invalid usage of '*' in function 'sum'
```
cc: rxin nongli cloud-fan
Author: gatorsmile <gatorsmile@gmail.com>
Closes #11208 from gatorsmile/sumDataSetResolution.
Diffstat (limited to 'sql/catalyst')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 138 | ||||
-rw-r--r-- | sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 5 |
2 files changed, 77 insertions, 66 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7a08c7dcfc..5951a70c48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -80,6 +80,7 @@ class Analyzer( EliminateUnions), Batch("Resolution", fixedPoint, ResolveRelations :: + ResolveStar :: ResolveReferences :: ResolveGroupingAnalytics :: ResolvePivot :: @@ -373,28 +374,83 @@ class Analyzer( } /** - * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from - * a logical plan node's children. + * Expand [[UnresolvedStar]] or [[ResolvedStar]] to the matching attributes in child's output. */ - object ResolveReferences extends Rule[LogicalPlan] { + object ResolveStar extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p: LogicalPlan if !p.childrenResolved => p + + // If the projection list contains Stars, expand it. + case p: Project if containsStar(p.projectList) => + val expanded = p.projectList.flatMap { + case s: Star => s.expand(p.child, resolver) + case ua @ UnresolvedAlias(_: UnresolvedFunction | _: CreateArray | _: CreateStruct, _) => + UnresolvedAlias(child = expandStarExpression(ua.child, p.child)) :: Nil + case a @ Alias(_: UnresolvedFunction | _: CreateArray | _: CreateStruct, _) => + a.withNewChildren(expandStarExpression(a.child, p.child) :: Nil) + .asInstanceOf[Alias] :: Nil + case o => o :: Nil + } + Project(projectList = expanded, p.child) + // If the aggregate function argument contains Stars, expand it. + case a: Aggregate if containsStar(a.aggregateExpressions) => + val expanded = a.aggregateExpressions.flatMap { + case s: Star => s.expand(a.child, resolver) + case o if containsStar(o :: Nil) => expandStarExpression(o, a.child) :: Nil + case o => o :: Nil + }.map(_.asInstanceOf[NamedExpression]) + a.copy(aggregateExpressions = expanded) + // If the script transformation input contains Stars, expand it. + case t: ScriptTransformation if containsStar(t.input) => + t.copy( + input = t.input.flatMap { + case s: Star => s.expand(t.child, resolver) + case o => o :: Nil + } + ) + case g: Generate if containsStar(g.generator.children) => + failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF") + } + + /** + * Returns true if `exprs` contains a [[Star]]. + */ + def containsStar(exprs: Seq[Expression]): Boolean = + exprs.exists(_.collect { case _: Star => true }.nonEmpty) + /** - * Foreach expression, expands the matching attribute.*'s in `child`'s input for the subtree - * rooted at each expression. + * Expands the matching attribute.*'s in `child`'s output. */ - def expandStarExpressions(exprs: Seq[Expression], child: LogicalPlan): Seq[Expression] = { - exprs.flatMap { - case s: Star => s.expand(child, resolver) - case e => - e.transformDown { - case f1: UnresolvedFunction if containsStar(f1.children) => - f1.copy(children = f1.children.flatMap { - case s: Star => s.expand(child, resolver) - case o => o :: Nil - }) - } :: Nil + def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = { + expr.transformUp { + case f1: UnresolvedFunction if containsStar(f1.children) => + f1.copy(children = f1.children.flatMap { + case s: Star => s.expand(child, resolver) + case o => o :: Nil + }) + case c: CreateStruct if containsStar(c.children) => + c.copy(children = c.children.flatMap { + case s: Star => s.expand(child, resolver) + case o => o :: Nil + }) + case c: CreateArray if containsStar(c.children) => + c.copy(children = c.children.flatMap { + case s: Star => s.expand(child, resolver) + case o => o :: Nil + }) + // count(*) has been replaced by count(1) + case o if containsStar(o.children) => + failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'") } } + } + /** + * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from + * a logical plan node's children. + */ + object ResolveReferences extends Rule[LogicalPlan] { /** * Generate a new logical plan for the right child with different expression IDs * for all conflicting attributes. @@ -456,48 +512,6 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p: LogicalPlan if !p.childrenResolved => p - // If the projection list contains Stars, expand it. - case p @ Project(projectList, child) if containsStar(projectList) => - Project( - projectList.flatMap { - case s: Star => s.expand(child, resolver) - case UnresolvedAlias(f @ UnresolvedFunction(_, args, _), _) if containsStar(args) => - val newChildren = expandStarExpressions(args, child) - UnresolvedAlias(child = f.copy(children = newChildren)) :: Nil - case a @ Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args) => - val newChildren = expandStarExpressions(args, child) - Alias(child = f.copy(children = newChildren), name)( - isGenerated = a.isGenerated) :: Nil - case UnresolvedAlias(c @ CreateArray(args), _) if containsStar(args) => - val expandedArgs = args.flatMap { - case s: Star => s.expand(child, resolver) - case o => o :: Nil - } - UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil - case UnresolvedAlias(c @ CreateStruct(args), _) if containsStar(args) => - val expandedArgs = args.flatMap { - case s: Star => s.expand(child, resolver) - case o => o :: Nil - } - UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil - case o => o :: Nil - }, - child) - - case t: ScriptTransformation if containsStar(t.input) => - t.copy( - input = t.input.flatMap { - case s: Star => s.expand(t.child, resolver) - case o => o :: Nil - } - ) - - // If the aggregate function argument contains Stars, expand it. - case a: Aggregate if containsStar(a.aggregateExpressions) => - val expanded = expandStarExpressions(a.aggregateExpressions, a.child) - .map(_.asInstanceOf[NamedExpression]) - a.copy(aggregateExpressions = expanded) - // To resolve duplicate expression IDs for Join and Intersect case j @ Join(left, right, _, _) if !j.duplicateResolved => j.copy(right = dedupRight(left, right)) @@ -592,12 +606,6 @@ class Analyzer( def findAliases(projectList: Seq[NamedExpression]): AttributeSet = { AttributeSet(projectList.collect { case a: Alias => a.toAttribute }) } - - /** - * Returns true if `exprs` contains a [[Star]]. - */ - def containsStar(exprs: Seq[Expression]): Boolean = - exprs.exists(_.collect { case _: Star => true }.nonEmpty) } protected[sql] def resolveExpression( @@ -923,8 +931,6 @@ class Analyzer( */ object ResolveGenerate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case g: Generate if ResolveReferences.containsStar(g.generator.children) => - failAnalysis("Cannot explode *, explode can only be applied on a specific column.") case p: Generate if !p.child.resolved || !p.generator.resolved => p case g: Generate if !g.resolved => g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 1b297525bd..c87a2e24bd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -190,6 +190,11 @@ class AnalysisErrorSuite extends AnalysisTest { "cannot resolve" :: "havingCondition" :: Nil) errorTest( + "unresolved star expansion in max", + testRelation2.groupBy('a)(sum(UnresolvedStar(None))), + "Invalid usage of '*'" :: "in expression 'sum'" :: Nil) + + errorTest( "bad casts", testRelation.select(Literal(1).cast(BinaryType).as('badCast)), "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil) |