aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorNong Li <nongli@gmail.com>2015-11-02 20:32:08 -0800
committerYin Huai <yhuai@databricks.com>2015-11-02 20:32:08 -0800
commit9cb5c731dadff9539126362827a258d6b65754bb (patch)
treeacc95aefbc77ba4c6557b15cdbfb6d09ec2118d4 /sql
parent2cef1bb0b560a03aa7308f694b0c66347b90c9ea (diff)
downloadspark-9cb5c731dadff9539126362827a258d6b65754bb.tar.gz
spark-9cb5c731dadff9539126362827a258d6b65754bb.tar.bz2
spark-9cb5c731dadff9539126362827a258d6b65754bb.zip
[SPARK-11329][SQL] Support star expansion for structs.
1. Supporting expanding structs in Projections. i.e. "SELECT s.*" where s is a struct type. This is fixed by allowing the expand function to handle structs in addition to tables. 2. Supporting expanding * inside aggregate functions of structs. "SELECT max(struct(col1, structCol.*))" This requires recursively expanding the expressions. In this case, it is the aggregate expression "max(...)" and we need to recursively expand its children inputs. Author: Nong Li <nongli@gmail.com> Closes #9343 from nongli/spark-11329.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala6
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala46
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala78
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Column.scala3
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala133
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala2
6 files changed, 230 insertions, 38 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 0fef043027..d7567e8613 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -466,9 +466,9 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser {
protected lazy val baseExpression: Parser[Expression] =
( "*" ^^^ UnresolvedStar(None)
- | ident <~ "." ~ "*" ^^ { case tableName => UnresolvedStar(Option(tableName)) }
- | primary
- )
+ | (ident <~ "."). + <~ "*" ^^ { case target => { UnresolvedStar(Option(target)) }
+ } | primary
+ )
protected lazy val signedPrimary: Parser[Expression] =
sign ~ primary ^^ { case s ~ e => if (s == "-") UnaryMinus(e) else e }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index beabacfc88..912c967b95 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -279,6 +279,24 @@ class Analyzer(
* a logical plan node's children.
*/
object ResolveReferences extends Rule[LogicalPlan] {
+ /**
+ * Foreach expression, expands the matching attribute.*'s in `child`'s input for the subtree
+ * rooted at each expression.
+ */
+ def expandStarExpressions(exprs: Seq[Expression], child: LogicalPlan): Seq[Expression] = {
+ exprs.flatMap {
+ case s: Star => s.expand(child, resolver)
+ case e =>
+ e.transformDown {
+ case f1: UnresolvedFunction if containsStar(f1.children) =>
+ f1.copy(children = f1.children.flatMap {
+ case s: Star => s.expand(child, resolver)
+ case o => o :: Nil
+ })
+ } :: Nil
+ }
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p: LogicalPlan if !p.childrenResolved => p
@@ -286,44 +304,42 @@ class Analyzer(
case p @ Project(projectList, child) if containsStar(projectList) =>
Project(
projectList.flatMap {
- case s: Star => s.expand(child.output, resolver)
+ case s: Star => s.expand(child, resolver)
case UnresolvedAlias(f @ UnresolvedFunction(_, args, _)) if containsStar(args) =>
- val expandedArgs = args.flatMap {
- case s: Star => s.expand(child.output, resolver)
- case o => o :: Nil
- }
- UnresolvedAlias(child = f.copy(children = expandedArgs)) :: Nil
+ val newChildren = expandStarExpressions(args, child)
+ UnresolvedAlias(child = f.copy(children = newChildren)) :: Nil
+ case Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args) =>
+ val newChildren = expandStarExpressions(args, child)
+ Alias(child = f.copy(children = newChildren), name)() :: Nil
case UnresolvedAlias(c @ CreateArray(args)) if containsStar(args) =>
val expandedArgs = args.flatMap {
- case s: Star => s.expand(child.output, resolver)
+ case s: Star => s.expand(child, resolver)
case o => o :: Nil
}
UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
case UnresolvedAlias(c @ CreateStruct(args)) if containsStar(args) =>
val expandedArgs = args.flatMap {
- case s: Star => s.expand(child.output, resolver)
+ case s: Star => s.expand(child, resolver)
case o => o :: Nil
}
UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
case o => o :: Nil
},
child)
+
case t: ScriptTransformation if containsStar(t.input) =>
t.copy(
input = t.input.flatMap {
- case s: Star => s.expand(t.child.output, resolver)
+ case s: Star => s.expand(t.child, resolver)
case o => o :: Nil
}
)
// If the aggregate function argument contains Stars, expand it.
case a: Aggregate if containsStar(a.aggregateExpressions) =>
- a.copy(
- aggregateExpressions = a.aggregateExpressions.flatMap {
- case s: Star => s.expand(a.child.output, resolver)
- case o => o :: Nil
- }
- )
+ val expanded = expandStarExpressions(a.aggregateExpressions, a.child)
+ .map(_.asInstanceOf[NamedExpression])
+ a.copy(aggregateExpressions = expanded)
// Special handling for cases when self-join introduce duplicate expression ids.
case j @ Join(left, right, _, _) if !j.selfJoinResolved =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index c973650039..6975662e2b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -18,12 +18,12 @@
package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.catalyst.{TableIdentifier, errors}
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.LeafNode
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode}
import org.apache.spark.sql.catalyst.trees.TreeNode
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.catalyst.{TableIdentifier, errors}
+import org.apache.spark.sql.types.{DataType, StructType}
/**
* Thrown when an invalid attempt is made to access a property of a tree that has yet to be fully
@@ -158,7 +158,7 @@ abstract class Star extends LeafExpression with NamedExpression {
override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
override lazy val resolved = false
- def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression]
+ def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression]
}
@@ -166,26 +166,68 @@ abstract class Star extends LeafExpression with NamedExpression {
* Represents all of the input attributes to a given relational operator, for example in
* "SELECT * FROM ...".
*
- * @param table an optional table that should be the target of the expansion. If omitted all
- * tables' columns are produced.
+ * This is also used to expand structs. For example:
+ * "SELECT record.* from (SELECT struct(a,b,c) as record ...)
+ *
+ * @param target an optional name that should be the target of the expansion. If omitted all
+ * targets' columns are produced. This can either be a table name or struct name. This
+ * is a list of identifiers that is the path of the expansion.
*/
-case class UnresolvedStar(table: Option[String]) extends Star with Unevaluable {
+case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevaluable {
+
+ override def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] = {
- override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = {
- val expandedAttributes: Seq[Attribute] = table match {
+ // First try to expand assuming it is table.*.
+ val expandedAttributes: Seq[Attribute] = target match {
// If there is no table specified, use all input attributes.
- case None => input
+ case None => input.output
// If there is a table, pick out attributes that are part of this table.
- case Some(t) => input.filter(_.qualifiers.filter(resolver(_, t)).nonEmpty)
+ case Some(t) => if (t.size == 1) {
+ input.output.filter(_.qualifiers.filter(resolver(_, t.head)).nonEmpty)
+ } else {
+ List()
+ }
}
- expandedAttributes.zip(input).map {
- case (n: NamedExpression, _) => n
- case (e, originalAttribute) =>
- Alias(e, originalAttribute.name)(qualifiers = originalAttribute.qualifiers)
+ if (!expandedAttributes.isEmpty) {
+ if (expandedAttributes.forall(_.isInstanceOf[NamedExpression])) {
+ return expandedAttributes
+ } else {
+ require(expandedAttributes.size == input.output.size)
+ expandedAttributes.zip(input.output).map {
+ case (e, originalAttribute) =>
+ Alias(e, originalAttribute.name)(qualifiers = originalAttribute.qualifiers)
+ }
+ }
+ return expandedAttributes
+ }
+
+ require(target.isDefined)
+
+ // Try to resolve it as a struct expansion. If there is a conflict and both are possible,
+ // (i.e. [name].* is both a table and a struct), the struct path can always be qualified.
+ val attribute = input.resolve(target.get, resolver)
+ if (attribute.isDefined) {
+ // This target resolved to an attribute in child. It must be a struct. Expand it.
+ attribute.get.dataType match {
+ case s: StructType => {
+ s.fields.map( f => {
+ val extract = GetStructField(attribute.get, f, s.getFieldIndex(f.name).get)
+ Alias(extract, target.get + "." + f.name)()
+ })
+ }
+ case _ => {
+ throw new AnalysisException("Can only star expand struct data types. Attribute: `" +
+ target.get + "`")
+ }
+ }
+ } else {
+ val from = input.inputSet.map(_.name).mkString(", ")
+ val targetString = target.get.mkString(".")
+ throw new AnalysisException(s"cannot resolve '$targetString.*' give input columns '$from'")
}
}
- override def toString: String = table.map(_ + ".").getOrElse("") + "*"
+ override def toString: String = target.map(_ + ".").getOrElse("") + "*"
}
/**
@@ -225,7 +267,7 @@ case class MultiAlias(child: Expression, names: Seq[String])
* @param expressions Expressions to expand.
*/
case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star with Unevaluable {
- override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = expressions
+ override def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] = expressions
override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")")
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index e4f4cf1533..3cde9d6cb4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -60,7 +60,8 @@ class Column(protected[sql] val expr: Expression) extends Logging {
def this(name: String) = this(name match {
case "*" => UnresolvedStar(None)
- case _ if name.endsWith(".*") => UnresolvedStar(Some(name.substring(0, name.length - 2)))
+ case _ if name.endsWith(".*") => UnresolvedStar(Some(UnresolvedAttribute.parseAttributeName(
+ name.substring(0, name.length - 2))))
case _ => UnresolvedAttribute.quotedString(name)
})
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 5413ef1287..ee54bff24b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1932,4 +1932,137 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
assert(sampled.count() == sampledOdd.count() + sampledEven.count())
}
}
+
+ test("Struct Star Expansion") {
+ val structDf = testData2.select("a", "b").as("record")
+
+ checkAnswer(
+ structDf.select($"record.a", $"record.b"),
+ Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil)
+
+ checkAnswer(
+ structDf.select($"record.*"),
+ Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil)
+
+ checkAnswer(
+ structDf.select($"record.*", $"record.*"),
+ Row(1, 1, 1, 1) :: Row(1, 2, 1, 2) :: Row(2, 1, 2, 1) :: Row(2, 2, 2, 2) ::
+ Row(3, 1, 3, 1) :: Row(3, 2, 3, 2) :: Nil)
+
+ checkAnswer(
+ sql("select struct(a, b) as r1, struct(b, a) as r2 from testData2").select($"r1.*", $"r2.*"),
+ Row(1, 1, 1, 1) :: Row(1, 2, 2, 1) :: Row(2, 1, 1, 2) :: Row(2, 2, 2, 2) ::
+ Row(3, 1, 1, 3) :: Row(3, 2, 2, 3) :: Nil)
+
+ // Try with a registered table.
+ sql("select struct(a, b) as record from testData2").registerTempTable("structTable")
+ checkAnswer(sql("SELECT record.* FROM structTable"),
+ Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil)
+
+ checkAnswer(sql(
+ """
+ | SELECT min(struct(record.*)) FROM
+ | (select struct(a,b) as record from testData2) tmp
+ """.stripMargin),
+ Row(Row(1, 1)) :: Nil)
+
+ // Try with an alias on the select list
+ checkAnswer(sql(
+ """
+ | SELECT max(struct(record.*)) as r FROM
+ | (select struct(a,b) as record from testData2) tmp
+ """.stripMargin).select($"r.*"),
+ Row(3, 2) :: Nil)
+
+ // With GROUP BY
+ checkAnswer(sql(
+ """
+ | SELECT min(struct(record.*)) FROM
+ | (select a as a, struct(a,b) as record from testData2) tmp
+ | GROUP BY a
+ """.stripMargin),
+ Row(Row(1, 1)) :: Row(Row(2, 1)) :: Row(Row(3, 1)) :: Nil)
+
+ // With GROUP BY and alias
+ checkAnswer(sql(
+ """
+ | SELECT max(struct(record.*)) as r FROM
+ | (select a as a, struct(a,b) as record from testData2) tmp
+ | GROUP BY a
+ """.stripMargin).select($"r.*"),
+ Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil)
+
+ // With GROUP BY and alias and additional fields in the struct
+ checkAnswer(sql(
+ """
+ | SELECT max(struct(a, record.*, b)) as r FROM
+ | (select a as a, b as b, struct(a,b) as record from testData2) tmp
+ | GROUP BY a
+ """.stripMargin).select($"r.*"),
+ Row(1, 1, 2, 2) :: Row(2, 2, 2, 2) :: Row(3, 3, 2, 2) :: Nil)
+
+ // Create a data set that contains nested structs.
+ val nestedStructData = sql(
+ """
+ | SELECT struct(r1, r2) as record FROM
+ | (SELECT struct(a, b) as r1, struct(b, a) as r2 FROM testData2) tmp
+ """.stripMargin)
+
+ checkAnswer(nestedStructData.select($"record.*"),
+ Row(Row(1, 1), Row(1, 1)) :: Row(Row(1, 2), Row(2, 1)) :: Row(Row(2, 1), Row(1, 2)) ::
+ Row(Row(2, 2), Row(2, 2)) :: Row(Row(3, 1), Row(1, 3)) :: Row(Row(3, 2), Row(2, 3)) :: Nil)
+ checkAnswer(nestedStructData.select($"record.r1"),
+ Row(Row(1, 1)) :: Row(Row(1, 2)) :: Row(Row(2, 1)) :: Row(Row(2, 2)) ::
+ Row(Row(3, 1)) :: Row(Row(3, 2)) :: Nil)
+ checkAnswer(
+ nestedStructData.select($"record.r1.*"),
+ Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil)
+
+ // Try with a registered table
+ nestedStructData.registerTempTable("nestedStructTable")
+ checkAnswer(sql("SELECT record.* FROM nestedStructTable"),
+ nestedStructData.select($"record.*"))
+ checkAnswer(sql("SELECT record.r1 FROM nestedStructTable"),
+ nestedStructData.select($"record.r1"))
+ checkAnswer(sql("SELECT record.r1.* FROM nestedStructTable"),
+ nestedStructData.select($"record.r1.*"))
+
+ // Create paths with unusual characters.
+ val specialCharacterPath = sql(
+ """
+ | SELECT struct(`col$.a_`, `a.b.c.`) as `r&&b.c` FROM
+ | (SELECT struct(a, b) as `col$.a_`, struct(b, a) as `a.b.c.` FROM testData2) tmp
+ """.stripMargin)
+ specialCharacterPath.registerTempTable("specialCharacterTable")
+ checkAnswer(specialCharacterPath.select($"`r&&b.c`.*"),
+ nestedStructData.select($"record.*"))
+ checkAnswer(sql("SELECT `r&&b.c`.`col$.a_` FROM specialCharacterTable"),
+ nestedStructData.select($"record.r1"))
+ checkAnswer(sql("SELECT `r&&b.c`.`a.b.c.` FROM specialCharacterTable"),
+ nestedStructData.select($"record.r2"))
+ checkAnswer(sql("SELECT `r&&b.c`.`col$.a_`.* FROM specialCharacterTable"),
+ nestedStructData.select($"record.r1.*"))
+
+ // Try star expanding a scalar. This should fail.
+ assert(intercept[AnalysisException](sql("select a.* from testData2")).getMessage.contains(
+ "Can only star expand struct data types."))
+
+ // Try resolving something not there.
+ assert(intercept[AnalysisException](sql("SELECT abc.* FROM nestedStructTable"))
+ .getMessage.contains("cannot resolve"))
+ }
+
+
+ test("Struct Star Expansion - Name conflict") {
+ // Create a data set that contains a naming conflict
+ val nameConflict = sql("SELECT struct(a, b) as nameConflict, a as a FROM testData2")
+ nameConflict.registerTempTable("nameConflict")
+ // Unqualified should resolve to table.
+ checkAnswer(sql("SELECT nameConflict.* FROM nameConflict"),
+ Row(Row(1, 1), 1) :: Row(Row(1, 2), 1) :: Row(Row(2, 1), 2) :: Row(Row(2, 2), 2) ::
+ Row(Row(3, 1), 3) :: Row(Row(3, 2), 3) :: Nil)
+ // Qualify the struct type with the table name.
+ checkAnswer(sql("SELECT nameConflict.nameConflict.* FROM nameConflict"),
+ Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil)
+ }
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 3697761f20..ab88c1e68f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -1505,7 +1505,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
// The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only
// has a single child which is tableName.
case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", Token(name, Nil) :: Nil) :: Nil) =>
- UnresolvedStar(Some(name))
+ UnresolvedStar(Some(UnresolvedAttribute.parseAttributeName(name)))
/* Aggregate Functions */
case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => Count(Literal(1))