diff options
author | Nong Li <nongli@gmail.com> | 2015-11-02 20:32:08 -0800 |
---|---|---|
committer | Yin Huai <yhuai@databricks.com> | 2015-11-02 20:32:08 -0800 |
commit | 9cb5c731dadff9539126362827a258d6b65754bb (patch) | |
tree | acc95aefbc77ba4c6557b15cdbfb6d09ec2118d4 /sql | |
parent | 2cef1bb0b560a03aa7308f694b0c66347b90c9ea (diff) | |
download | spark-9cb5c731dadff9539126362827a258d6b65754bb.tar.gz spark-9cb5c731dadff9539126362827a258d6b65754bb.tar.bz2 spark-9cb5c731dadff9539126362827a258d6b65754bb.zip |
[SPARK-11329][SQL] Support star expansion for structs.
1. Support expanding structs in projections, i.e.
"SELECT s.*" where s is a struct type.
This is fixed by allowing the expand function to handle structs in addition to tables.
2. Support expanding * inside aggregate functions of structs.
"SELECT max(struct(col1, structCol.*))"
This requires recursively expanding the expressions. In this case, it is the aggregate
expression "max(...)" and we need to recursively expand its children inputs.
Author: Nong Li <nongli@gmail.com>
Closes #9343 from nongli/spark-11329.
Diffstat (limited to 'sql')
6 files changed, 230 insertions, 38 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 0fef043027..d7567e8613 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -466,9 +466,9 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val baseExpression: Parser[Expression] = ( "*" ^^^ UnresolvedStar(None) - | ident <~ "." ~ "*" ^^ { case tableName => UnresolvedStar(Option(tableName)) } - | primary - ) + | (ident <~ "."). + <~ "*" ^^ { case target => { UnresolvedStar(Option(target)) } + } | primary + ) protected lazy val signedPrimary: Parser[Expression] = sign ~ primary ^^ { case s ~ e => if (s == "-") UnaryMinus(e) else e } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index beabacfc88..912c967b95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -279,6 +279,24 @@ class Analyzer( * a logical plan node's children. */ object ResolveReferences extends Rule[LogicalPlan] { + /** + * Foreach expression, expands the matching attribute.*'s in `child`'s input for the subtree + * rooted at each expression. 
+ */ + def expandStarExpressions(exprs: Seq[Expression], child: LogicalPlan): Seq[Expression] = { + exprs.flatMap { + case s: Star => s.expand(child, resolver) + case e => + e.transformDown { + case f1: UnresolvedFunction if containsStar(f1.children) => + f1.copy(children = f1.children.flatMap { + case s: Star => s.expand(child, resolver) + case o => o :: Nil + }) + } :: Nil + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p: LogicalPlan if !p.childrenResolved => p @@ -286,44 +304,42 @@ class Analyzer( case p @ Project(projectList, child) if containsStar(projectList) => Project( projectList.flatMap { - case s: Star => s.expand(child.output, resolver) + case s: Star => s.expand(child, resolver) case UnresolvedAlias(f @ UnresolvedFunction(_, args, _)) if containsStar(args) => - val expandedArgs = args.flatMap { - case s: Star => s.expand(child.output, resolver) - case o => o :: Nil - } - UnresolvedAlias(child = f.copy(children = expandedArgs)) :: Nil + val newChildren = expandStarExpressions(args, child) + UnresolvedAlias(child = f.copy(children = newChildren)) :: Nil + case Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args) => + val newChildren = expandStarExpressions(args, child) + Alias(child = f.copy(children = newChildren), name)() :: Nil case UnresolvedAlias(c @ CreateArray(args)) if containsStar(args) => val expandedArgs = args.flatMap { - case s: Star => s.expand(child.output, resolver) + case s: Star => s.expand(child, resolver) case o => o :: Nil } UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil case UnresolvedAlias(c @ CreateStruct(args)) if containsStar(args) => val expandedArgs = args.flatMap { - case s: Star => s.expand(child.output, resolver) + case s: Star => s.expand(child, resolver) case o => o :: Nil } UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil case o => o :: Nil }, child) + case t: ScriptTransformation if containsStar(t.input) => t.copy( input = t.input.flatMap { - case s: 
Star => s.expand(t.child.output, resolver) + case s: Star => s.expand(t.child, resolver) case o => o :: Nil } ) // If the aggregate function argument contains Stars, expand it. case a: Aggregate if containsStar(a.aggregateExpressions) => - a.copy( - aggregateExpressions = a.aggregateExpressions.flatMap { - case s: Star => s.expand(a.child.output, resolver) - case o => o :: Nil - } - ) + val expanded = expandStarExpressions(a.aggregateExpressions, a.child) + .map(_.asInstanceOf[NamedExpression]) + a.copy(aggregateExpressions = expanded) // Special handling for cases when self-join introduce duplicate expression ids. case j @ Join(left, right, _, _) if !j.selfJoinResolved => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index c973650039..6975662e2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -18,12 +18,12 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.{TableIdentifier, errors} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.LeafNode +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode} import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.catalyst.{TableIdentifier, errors} +import org.apache.spark.sql.types.{DataType, StructType} /** * Thrown when an invalid attempt is made to access a property of a tree that has yet to be fully @@ -158,7 +158,7 @@ abstract class Star extends LeafExpression with NamedExpression { 
override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") override lazy val resolved = false - def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] + def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] } @@ -166,26 +166,68 @@ abstract class Star extends LeafExpression with NamedExpression { * Represents all of the input attributes to a given relational operator, for example in * "SELECT * FROM ...". * - * @param table an optional table that should be the target of the expansion. If omitted all - * tables' columns are produced. + * This is also used to expand structs. For example: + * "SELECT record.* from (SELECT struct(a,b,c) as record ...) + * + * @param target an optional name that should be the target of the expansion. If omitted all + * targets' columns are produced. This can either be a table name or struct name. This + * is a list of identifiers that is the path of the expansion. */ -case class UnresolvedStar(table: Option[String]) extends Star with Unevaluable { +case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevaluable { + + override def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] = { - override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = { - val expandedAttributes: Seq[Attribute] = table match { + // First try to expand assuming it is table.*. + val expandedAttributes: Seq[Attribute] = target match { // If there is no table specified, use all input attributes. - case None => input + case None => input.output // If there is a table, pick out attributes that are part of this table. 
- case Some(t) => input.filter(_.qualifiers.filter(resolver(_, t)).nonEmpty) + case Some(t) => if (t.size == 1) { + input.output.filter(_.qualifiers.filter(resolver(_, t.head)).nonEmpty) + } else { + List() + } } - expandedAttributes.zip(input).map { - case (n: NamedExpression, _) => n - case (e, originalAttribute) => - Alias(e, originalAttribute.name)(qualifiers = originalAttribute.qualifiers) + if (!expandedAttributes.isEmpty) { + if (expandedAttributes.forall(_.isInstanceOf[NamedExpression])) { + return expandedAttributes + } else { + require(expandedAttributes.size == input.output.size) + expandedAttributes.zip(input.output).map { + case (e, originalAttribute) => + Alias(e, originalAttribute.name)(qualifiers = originalAttribute.qualifiers) + } + } + return expandedAttributes + } + + require(target.isDefined) + + // Try to resolve it as a struct expansion. If there is a conflict and both are possible, + // (i.e. [name].* is both a table and a struct), the struct path can always be qualified. + val attribute = input.resolve(target.get, resolver) + if (attribute.isDefined) { + // This target resolved to an attribute in child. It must be a struct. Expand it. + attribute.get.dataType match { + case s: StructType => { + s.fields.map( f => { + val extract = GetStructField(attribute.get, f, s.getFieldIndex(f.name).get) + Alias(extract, target.get + "." + f.name)() + }) + } + case _ => { + throw new AnalysisException("Can only star expand struct data types. 
Attribute: `" + + target.get + "`") + } + } + } else { + val from = input.inputSet.map(_.name).mkString(", ") + val targetString = target.get.mkString(".") + throw new AnalysisException(s"cannot resolve '$targetString.*' give input columns '$from'") } } - override def toString: String = table.map(_ + ".").getOrElse("") + "*" + override def toString: String = target.map(_ + ".").getOrElse("") + "*" } /** @@ -225,7 +267,7 @@ case class MultiAlias(child: Expression, names: Seq[String]) * @param expressions Expressions to expand. */ case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star with Unevaluable { - override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = expressions + override def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] = expressions override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index e4f4cf1533..3cde9d6cb4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -60,7 +60,8 @@ class Column(protected[sql] val expr: Expression) extends Logging { def this(name: String) = this(name match { case "*" => UnresolvedStar(None) - case _ if name.endsWith(".*") => UnresolvedStar(Some(name.substring(0, name.length - 2))) + case _ if name.endsWith(".*") => UnresolvedStar(Some(UnresolvedAttribute.parseAttributeName( + name.substring(0, name.length - 2)))) case _ => UnresolvedAttribute.quotedString(name) }) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5413ef1287..ee54bff24b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1932,4 +1932,137 @@ class 
SQLQuerySuite extends QueryTest with SharedSQLContext { assert(sampled.count() == sampledOdd.count() + sampledEven.count()) } } + + test("Struct Star Expansion") { + val structDf = testData2.select("a", "b").as("record") + + checkAnswer( + structDf.select($"record.a", $"record.b"), + Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) + + checkAnswer( + structDf.select($"record.*"), + Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) + + checkAnswer( + structDf.select($"record.*", $"record.*"), + Row(1, 1, 1, 1) :: Row(1, 2, 1, 2) :: Row(2, 1, 2, 1) :: Row(2, 2, 2, 2) :: + Row(3, 1, 3, 1) :: Row(3, 2, 3, 2) :: Nil) + + checkAnswer( + sql("select struct(a, b) as r1, struct(b, a) as r2 from testData2").select($"r1.*", $"r2.*"), + Row(1, 1, 1, 1) :: Row(1, 2, 2, 1) :: Row(2, 1, 1, 2) :: Row(2, 2, 2, 2) :: + Row(3, 1, 1, 3) :: Row(3, 2, 2, 3) :: Nil) + + // Try with a registered table. + sql("select struct(a, b) as record from testData2").registerTempTable("structTable") + checkAnswer(sql("SELECT record.* FROM structTable"), + Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) + + checkAnswer(sql( + """ + | SELECT min(struct(record.*)) FROM + | (select struct(a,b) as record from testData2) tmp + """.stripMargin), + Row(Row(1, 1)) :: Nil) + + // Try with an alias on the select list + checkAnswer(sql( + """ + | SELECT max(struct(record.*)) as r FROM + | (select struct(a,b) as record from testData2) tmp + """.stripMargin).select($"r.*"), + Row(3, 2) :: Nil) + + // With GROUP BY + checkAnswer(sql( + """ + | SELECT min(struct(record.*)) FROM + | (select a as a, struct(a,b) as record from testData2) tmp + | GROUP BY a + """.stripMargin), + Row(Row(1, 1)) :: Row(Row(2, 1)) :: Row(Row(3, 1)) :: Nil) + + // With GROUP BY and alias + checkAnswer(sql( + """ + | SELECT max(struct(record.*)) as r FROM + | (select a as a, struct(a,b) as record from testData2) tmp + | GROUP BY a + 
""".stripMargin).select($"r.*"), + Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil) + + // With GROUP BY and alias and additional fields in the struct + checkAnswer(sql( + """ + | SELECT max(struct(a, record.*, b)) as r FROM + | (select a as a, b as b, struct(a,b) as record from testData2) tmp + | GROUP BY a + """.stripMargin).select($"r.*"), + Row(1, 1, 2, 2) :: Row(2, 2, 2, 2) :: Row(3, 3, 2, 2) :: Nil) + + // Create a data set that contains nested structs. + val nestedStructData = sql( + """ + | SELECT struct(r1, r2) as record FROM + | (SELECT struct(a, b) as r1, struct(b, a) as r2 FROM testData2) tmp + """.stripMargin) + + checkAnswer(nestedStructData.select($"record.*"), + Row(Row(1, 1), Row(1, 1)) :: Row(Row(1, 2), Row(2, 1)) :: Row(Row(2, 1), Row(1, 2)) :: + Row(Row(2, 2), Row(2, 2)) :: Row(Row(3, 1), Row(1, 3)) :: Row(Row(3, 2), Row(2, 3)) :: Nil) + checkAnswer(nestedStructData.select($"record.r1"), + Row(Row(1, 1)) :: Row(Row(1, 2)) :: Row(Row(2, 1)) :: Row(Row(2, 2)) :: + Row(Row(3, 1)) :: Row(Row(3, 2)) :: Nil) + checkAnswer( + nestedStructData.select($"record.r1.*"), + Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) + + // Try with a registered table + nestedStructData.registerTempTable("nestedStructTable") + checkAnswer(sql("SELECT record.* FROM nestedStructTable"), + nestedStructData.select($"record.*")) + checkAnswer(sql("SELECT record.r1 FROM nestedStructTable"), + nestedStructData.select($"record.r1")) + checkAnswer(sql("SELECT record.r1.* FROM nestedStructTable"), + nestedStructData.select($"record.r1.*")) + + // Create paths with unusual characters. 
+ val specialCharacterPath = sql( + """ + | SELECT struct(`col$.a_`, `a.b.c.`) as `r&&b.c` FROM + | (SELECT struct(a, b) as `col$.a_`, struct(b, a) as `a.b.c.` FROM testData2) tmp + """.stripMargin) + specialCharacterPath.registerTempTable("specialCharacterTable") + checkAnswer(specialCharacterPath.select($"`r&&b.c`.*"), + nestedStructData.select($"record.*")) + checkAnswer(sql("SELECT `r&&b.c`.`col$.a_` FROM specialCharacterTable"), + nestedStructData.select($"record.r1")) + checkAnswer(sql("SELECT `r&&b.c`.`a.b.c.` FROM specialCharacterTable"), + nestedStructData.select($"record.r2")) + checkAnswer(sql("SELECT `r&&b.c`.`col$.a_`.* FROM specialCharacterTable"), + nestedStructData.select($"record.r1.*")) + + // Try star expanding a scalar. This should fail. + assert(intercept[AnalysisException](sql("select a.* from testData2")).getMessage.contains( + "Can only star expand struct data types.")) + + // Try resolving something not there. + assert(intercept[AnalysisException](sql("SELECT abc.* FROM nestedStructTable")) + .getMessage.contains("cannot resolve")) + } + + + test("Struct Star Expansion - Name conflict") { + // Create a data set that contains a naming conflict + val nameConflict = sql("SELECT struct(a, b) as nameConflict, a as a FROM testData2") + nameConflict.registerTempTable("nameConflict") + // Unqualified should resolve to table. + checkAnswer(sql("SELECT nameConflict.* FROM nameConflict"), + Row(Row(1, 1), 1) :: Row(Row(1, 2), 1) :: Row(Row(2, 1), 2) :: Row(Row(2, 2), 2) :: + Row(Row(3, 1), 3) :: Row(Row(3, 2), 3) :: Nil) + // Qualify the struct type with the table name. 
+ checkAnswer(sql("SELECT nameConflict.nameConflict.* FROM nameConflict"), + Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 3697761f20..ab88c1e68f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -1505,7 +1505,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only // has a single child which is tableName. case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", Token(name, Nil) :: Nil) :: Nil) => - UnresolvedStar(Some(name)) + UnresolvedStar(Some(UnresolvedAttribute.parseAttributeName(name))) /* Aggregate Functions */ case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => Count(Literal(1)) |