aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorgatorsmile <gatorsmile@gmail.com>2016-03-24 11:13:36 +0800
committerWenchen Fan <wenchen@databricks.com>2016-03-24 11:13:36 +0800
commitf42eaf42bdca8bc6f390f1f31ee60faa1662489b (patch)
treed5f84d22eb06d95fd4b20a79a6ccaaa34c911d89
parentde4e48b62b998d45d4a749234741a45534719497 (diff)
downloadspark-f42eaf42bdca8bc6f390f1f31ee60faa1662489b.tar.gz
spark-f42eaf42bdca8bc6f390f1f31ee60faa1662489b.tar.bz2
spark-f42eaf42bdca8bc6f390f1f31ee60faa1662489b.zip
[SPARK-14085][SQL] Star Expansion for Hash
#### What changes were proposed in this pull request? This PR is to support star expansion in hash. For example, ```SQL val structDf = testData2.select("a", "b").as("record") structDf.select(hash($"*") ``` In addition, it refactors the codes for the rule `ResolveStar` and fixes a regression for star expansion in group by when using SQL API. For example, ```SQL SELECT * FROM testData2 group by a, b ``` cc cloud-fan Now, the code for star resolution is much cleaner. The coverage is better. Could you check if this refactoring is good? Thanks! #### How was this patch tested? Added a few test cases to cover it. Author: gatorsmile <gatorsmile@gmail.com> Closes #11904 from gatorsmile/starResolution.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala40
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala19
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala8
3 files changed, 50 insertions, 17 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 178e9402fa..54543eebb7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -380,27 +380,12 @@ class Analyzer(
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p: LogicalPlan if !p.childrenResolved => p
-
// If the projection list contains Stars, expand it.
case p: Project if containsStar(p.projectList) =>
- val expanded = p.projectList.flatMap {
- case s: Star => s.expand(p.child, resolver)
- case ua @ UnresolvedAlias(_: UnresolvedFunction | _: CreateArray | _: CreateStruct, _) =>
- UnresolvedAlias(child = expandStarExpression(ua.child, p.child)) :: Nil
- case a @ Alias(_: UnresolvedFunction | _: CreateArray | _: CreateStruct, _) =>
- a.withNewChildren(expandStarExpression(a.child, p.child) :: Nil)
- .asInstanceOf[Alias] :: Nil
- case o => o :: Nil
- }
- Project(projectList = expanded, p.child)
+ p.copy(projectList = buildExpandedProjectList(p.projectList, p.child))
// If the aggregate function argument contains Stars, expand it.
case a: Aggregate if containsStar(a.aggregateExpressions) =>
- val expanded = a.aggregateExpressions.flatMap {
- case s: Star => s.expand(a.child, resolver)
- case o if containsStar(o :: Nil) => expandStarExpression(o, a.child) :: Nil
- case o => o :: Nil
- }.map(_.asInstanceOf[NamedExpression])
- a.copy(aggregateExpressions = expanded)
+ a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child))
// If the script transformation input contains Stars, expand it.
case t: ScriptTransformation if containsStar(t.input) =>
t.copy(
@@ -414,6 +399,22 @@ class Analyzer(
}
/**
+ * Build a project list for Project/Aggregate and expand the star if possible
+ */
+ private def buildExpandedProjectList(
+ exprs: Seq[NamedExpression],
+ child: LogicalPlan): Seq[NamedExpression] = {
+ exprs.flatMap {
+ // Using Dataframe/Dataset API: testData2.groupBy($"a", $"b").agg($"*")
+ case s: Star => s.expand(child, resolver)
+ // Using SQL API without running ResolveAlias: SELECT * FROM testData2 group by a, b
+ case UnresolvedAlias(s: Star, _) => s.expand(child, resolver)
+ case o if containsStar(o :: Nil) => expandStarExpression(o, child) :: Nil
+ case o => o :: Nil
+ }.map(_.asInstanceOf[NamedExpression])
+ }
+
+ /**
* Returns true if `exprs` contains a [[Star]].
*/
def containsStar(exprs: Seq[Expression]): Boolean =
@@ -439,6 +440,11 @@ class Analyzer(
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
+ case p: Murmur3Hash if containsStar(p.children) =>
+ p.copy(children = p.children.flatMap {
+ case s: Star => s.expand(child, resolver)
+ case o => o :: Nil
+ })
// count(*) has been replaced by count(1)
case o if containsStar(o.children) =>
failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index e6b7dc9199..ec4e7b2042 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -181,6 +181,25 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
assert(structDf.select(array($"record.*").as("a")).first().getAs[Seq[Int]](0) === Seq(1, 1))
}
+ test("Star Expansion - hash") {
+ val structDf = testData2.select("a", "b").as("record")
+ checkAnswer(
+ structDf.groupBy($"a", $"b").agg(min(hash($"a", $"*"))),
+ structDf.groupBy($"a", $"b").agg(min(hash($"a", $"a", $"b"))))
+
+ checkAnswer(
+ structDf.groupBy($"a", $"b").agg(hash($"a", $"*")),
+ structDf.groupBy($"a", $"b").agg(hash($"a", $"a", $"b")))
+
+ checkAnswer(
+ structDf.select(hash($"*")),
+ structDf.select(hash($"record.*")))
+
+ checkAnswer(
+ structDf.select(hash($"a", $"*")),
+ structDf.select(hash($"a", $"record.*")))
+ }
+
test("Star Expansion - explode should fail with a meaningful message if it takes a star") {
val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv")
val e = intercept[AnalysisException] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index bd13474e73..4f36b1b42a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1941,6 +1941,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
}
+ test("Star Expansion - group by") {
+ withSQLConf("spark.sql.retainGroupColumns" -> "false") {
+ checkAnswer(
+ testData2.groupBy($"a", $"b").agg($"*"),
+ sql("SELECT * FROM testData2 group by a, b"))
+ }
+ }
+
test("Common subexpression elimination") {
// TODO: support subexpression elimination in whole stage codegen
withSQLConf("spark.sql.codegen.wholeStage" -> "false") {