author    gatorsmile <gatorsmile@gmail.com>      2016-03-22 08:21:02 +0800
committer Wenchen Fan <wenchen@databricks.com>   2016-03-22 08:21:02 +0800
commit    3f49e0766f3a369a44e14632de68c657773b7a27
tree      98a493a351dc476656eff031d4f97109ffeed0e0
parent    b5f1ab701a167a728bb006e01b392b203da84391
[SPARK-13320][SQL] Support Star in CreateStruct/CreateArray and Error Handling When DataFrame/Dataset Functions Use Star
This PR resolves two issues.

First, it supports expanding `*` inside CreateStruct/CreateArray (for example, inside aggregate functions) when using the DataFrame/Dataset APIs:

```scala
structDf.groupBy($"a").agg(min(struct($"record.*")))
```

Second, it improves the error message for invalid star usage in the DataFrame/Dataset APIs. For example:

```scala
pagecounts4PartitionsDS
  .map(line => (line._1, line._3))
  .toDF()
  .groupBy($"_1")
  .agg(sum("*") as "sumOccurances")
```

Before the fix, this invalid usage produced a confusing error message:

```
org.apache.spark.sql.AnalysisException: cannot resolve '_1' given input columns _1, _2;
```

After the fix, the message is:

```
org.apache.spark.sql.AnalysisException: Invalid usage of '*' in function 'sum'
```

cc: rxin nongli cloud-fan

Author: gatorsmile <gatorsmile@gmail.com>

Closes #11208 from gatorsmile/sumDataSetResolution.
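The snippets above are fragments. Below is a minimal, self-contained sketch (not part of the patch) of the two behaviors, assuming a Spark 1.6-era shell where `sqlContext` is predefined; the `structDf` data is purely illustrative.

```scala
// Assumes a spark-shell session where `sqlContext` is predefined.
import sqlContext.implicits._
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.functions.{min, struct, sum}

// Illustrative data: any two-column DataFrame aliased as "record" works.
val structDf = Seq((1, 1), (1, 2), (2, 1)).toDF("a", "b").as("record")

// Newly supported: '*' expands inside struct()/array() arguments,
// including when they appear inside aggregate functions.
structDf.groupBy($"a").agg(min(struct($"record.*"))).show()

// Still invalid, but now rejected with an explicit star-usage error
// instead of an unrelated "cannot resolve" message.
try {
  structDf.groupBy($"a").agg(sum("*")).show()
} catch {
  case e: AnalysisException => println(e.getMessage)
}
```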
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala           | 138
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala |   5
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala                           |  35
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala             |  12
4 files changed, 113 insertions(+), 77 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 7a08c7dcfc..5951a70c48 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -80,6 +80,7 @@ class Analyzer(
EliminateUnions),
Batch("Resolution", fixedPoint,
ResolveRelations ::
+ ResolveStar ::
ResolveReferences ::
ResolveGroupingAnalytics ::
ResolvePivot ::
@@ -373,28 +374,83 @@ class Analyzer(
}
/**
- * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from
- * a logical plan node's children.
+ * Expand [[UnresolvedStar]] or [[ResolvedStar]] to the matching attributes in child's output.
*/
- object ResolveReferences extends Rule[LogicalPlan] {
+ object ResolveStar extends Rule[LogicalPlan] {
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case p: LogicalPlan if !p.childrenResolved => p
+
+ // If the projection list contains Stars, expand it.
+ case p: Project if containsStar(p.projectList) =>
+ val expanded = p.projectList.flatMap {
+ case s: Star => s.expand(p.child, resolver)
+ case ua @ UnresolvedAlias(_: UnresolvedFunction | _: CreateArray | _: CreateStruct, _) =>
+ UnresolvedAlias(child = expandStarExpression(ua.child, p.child)) :: Nil
+ case a @ Alias(_: UnresolvedFunction | _: CreateArray | _: CreateStruct, _) =>
+ a.withNewChildren(expandStarExpression(a.child, p.child) :: Nil)
+ .asInstanceOf[Alias] :: Nil
+ case o => o :: Nil
+ }
+ Project(projectList = expanded, p.child)
+ // If the aggregate function argument contains Stars, expand it.
+ case a: Aggregate if containsStar(a.aggregateExpressions) =>
+ val expanded = a.aggregateExpressions.flatMap {
+ case s: Star => s.expand(a.child, resolver)
+ case o if containsStar(o :: Nil) => expandStarExpression(o, a.child) :: Nil
+ case o => o :: Nil
+ }.map(_.asInstanceOf[NamedExpression])
+ a.copy(aggregateExpressions = expanded)
+ // If the script transformation input contains Stars, expand it.
+ case t: ScriptTransformation if containsStar(t.input) =>
+ t.copy(
+ input = t.input.flatMap {
+ case s: Star => s.expand(t.child, resolver)
+ case o => o :: Nil
+ }
+ )
+ case g: Generate if containsStar(g.generator.children) =>
+ failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF")
+ }
+
+ /**
+ * Returns true if `exprs` contains a [[Star]].
+ */
+ def containsStar(exprs: Seq[Expression]): Boolean =
+ exprs.exists(_.collect { case _: Star => true }.nonEmpty)
+
/**
- * Foreach expression, expands the matching attribute.*'s in `child`'s input for the subtree
- * rooted at each expression.
+ * Expands the matching attribute.*'s in `child`'s output.
*/
- def expandStarExpressions(exprs: Seq[Expression], child: LogicalPlan): Seq[Expression] = {
- exprs.flatMap {
- case s: Star => s.expand(child, resolver)
- case e =>
- e.transformDown {
- case f1: UnresolvedFunction if containsStar(f1.children) =>
- f1.copy(children = f1.children.flatMap {
- case s: Star => s.expand(child, resolver)
- case o => o :: Nil
- })
- } :: Nil
+ def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = {
+ expr.transformUp {
+ case f1: UnresolvedFunction if containsStar(f1.children) =>
+ f1.copy(children = f1.children.flatMap {
+ case s: Star => s.expand(child, resolver)
+ case o => o :: Nil
+ })
+ case c: CreateStruct if containsStar(c.children) =>
+ c.copy(children = c.children.flatMap {
+ case s: Star => s.expand(child, resolver)
+ case o => o :: Nil
+ })
+ case c: CreateArray if containsStar(c.children) =>
+ c.copy(children = c.children.flatMap {
+ case s: Star => s.expand(child, resolver)
+ case o => o :: Nil
+ })
+ // count(*) has been replaced by count(1)
+ case o if containsStar(o.children) =>
+ failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'")
}
}
+ }
+ /**
+ * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from
+ * a logical plan node's children.
+ */
+ object ResolveReferences extends Rule[LogicalPlan] {
/**
* Generate a new logical plan for the right child with different expression IDs
* for all conflicting attributes.
@@ -456,48 +512,6 @@ class Analyzer(
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p: LogicalPlan if !p.childrenResolved => p
- // If the projection list contains Stars, expand it.
- case p @ Project(projectList, child) if containsStar(projectList) =>
- Project(
- projectList.flatMap {
- case s: Star => s.expand(child, resolver)
- case UnresolvedAlias(f @ UnresolvedFunction(_, args, _), _) if containsStar(args) =>
- val newChildren = expandStarExpressions(args, child)
- UnresolvedAlias(child = f.copy(children = newChildren)) :: Nil
- case a @ Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args) =>
- val newChildren = expandStarExpressions(args, child)
- Alias(child = f.copy(children = newChildren), name)(
- isGenerated = a.isGenerated) :: Nil
- case UnresolvedAlias(c @ CreateArray(args), _) if containsStar(args) =>
- val expandedArgs = args.flatMap {
- case s: Star => s.expand(child, resolver)
- case o => o :: Nil
- }
- UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
- case UnresolvedAlias(c @ CreateStruct(args), _) if containsStar(args) =>
- val expandedArgs = args.flatMap {
- case s: Star => s.expand(child, resolver)
- case o => o :: Nil
- }
- UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
- case o => o :: Nil
- },
- child)
-
- case t: ScriptTransformation if containsStar(t.input) =>
- t.copy(
- input = t.input.flatMap {
- case s: Star => s.expand(t.child, resolver)
- case o => o :: Nil
- }
- )
-
- // If the aggregate function argument contains Stars, expand it.
- case a: Aggregate if containsStar(a.aggregateExpressions) =>
- val expanded = expandStarExpressions(a.aggregateExpressions, a.child)
- .map(_.asInstanceOf[NamedExpression])
- a.copy(aggregateExpressions = expanded)
-
// To resolve duplicate expression IDs for Join and Intersect
case j @ Join(left, right, _, _) if !j.duplicateResolved =>
j.copy(right = dedupRight(left, right))
@@ -592,12 +606,6 @@ class Analyzer(
def findAliases(projectList: Seq[NamedExpression]): AttributeSet = {
AttributeSet(projectList.collect { case a: Alias => a.toAttribute })
}
-
- /**
- * Returns true if `exprs` contains a [[Star]].
- */
- def containsStar(exprs: Seq[Expression]): Boolean =
- exprs.exists(_.collect { case _: Star => true }.nonEmpty)
}
protected[sql] def resolveExpression(
@@ -923,8 +931,6 @@ class Analyzer(
*/
object ResolveGenerate extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
- case g: Generate if ResolveReferences.containsStar(g.generator.children) =>
- failAnalysis("Cannot explode *, explode can only be applied on a specific column.")
case p: Generate if !p.child.resolved || !p.generator.resolved => p
case g: Generate if !g.resolved =>
g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name)))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 1b297525bd..c87a2e24bd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -190,6 +190,11 @@ class AnalysisErrorSuite extends AnalysisTest {
"cannot resolve" :: "havingCondition" :: Nil)
errorTest(
+ "unresolved star expansion in max",
+ testRelation2.groupBy('a)(sum(UnresolvedStar(None))),
+ "Invalid usage of '*'" :: "in expression 'sum'" :: Nil)
+
+ errorTest(
"bad casts",
testRelation.select(Literal(1).cast(BinaryType).as('badCast)),
"cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index df2a287030..d03597ee5d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -166,22 +166,43 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
)
}
- test("SPARK-8930: explode should fail with a meaningful message if it takes a star") {
+ test("Star Expansion - CreateStruct and CreateArray") {
+ val structDf = testData2.select("a", "b").as("record")
+ // CreateStruct and CreateArray in aggregateExpressions
+ assert(structDf.groupBy($"a").agg(min(struct($"record.*"))).first() == Row(3, Row(3, 1)))
+ assert(structDf.groupBy($"a").agg(min(array($"record.*"))).first() == Row(3, Seq(3, 1)))
+
+ // CreateStruct and CreateArray in project list (unresolved alias)
+ assert(structDf.select(struct($"record.*")).first() == Row(Row(1, 1)))
+ assert(structDf.select(array($"record.*")).first().getAs[Seq[Int]](0) === Seq(1, 1))
+
+ // CreateStruct and CreateArray in project list (alias)
+ assert(structDf.select(struct($"record.*").as("a")).first() == Row(Row(1, 1)))
+ assert(structDf.select(array($"record.*").as("a")).first().getAs[Seq[Int]](0) === Seq(1, 1))
+ }
+
+ test("Star Expansion - explode should fail with a meaningful message if it takes a star") {
val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv")
val e = intercept[AnalysisException] {
df.explode($"*") { case Row(prefix: String, csv: String) =>
csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq
}.queryExecution.assertAnalyzed()
}
- assert(e.getMessage.contains(
- "Cannot explode *, explode can only be applied on a specific column."))
+ assert(e.getMessage.contains("Invalid usage of '*' in explode/json_tuple/UDTF"))
- df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) =>
- csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq
- }.queryExecution.assertAnalyzed()
+ checkAnswer(
+ df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) =>
+ csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq
+ },
+ Row("1", "1,2", "1:1") ::
+ Row("1", "1,2", "1:2") ::
+ Row("2", "4", "2:4") ::
+ Row("3", "7,8,9", "3:7") ::
+ Row("3", "7,8,9", "3:8") ::
+ Row("3", "7,8,9", "3:9") :: Nil)
}
- test("explode alias and star") {
+ test("Star Expansion - explode alias and star") {
val df = Seq((Array("a"), 1)).toDF("a", "b")
checkAnswer(
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 6bedead91b..2806b87f33 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -738,20 +738,24 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
.queryExecution.analyzed
}
+ test("Star Expansion - script transform") {
+ val data = (1 to 100000).map { i => (i, i, i) }
+ data.toDF("d1", "d2", "d3").registerTempTable("script_trans")
+ assert(100000 === sql("SELECT TRANSFORM (*) USING 'cat' FROM script_trans").count())
+ }
+
test("test script transform for stdout") {
val data = (1 to 100000).map { i => (i, i, i) }
data.toDF("d1", "d2", "d3").registerTempTable("script_trans")
assert(100000 ===
- sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat' AS (a,b,c) FROM script_trans")
- .queryExecution.toRdd.count())
+ sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat' AS (a,b,c) FROM script_trans").count())
}
test("test script transform for stderr") {
val data = (1 to 100000).map { i => (i, i, i) }
data.toDF("d1", "d2", "d3").registerTempTable("script_trans")
assert(0 ===
- sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat 1>&2' AS (a,b,c) FROM script_trans")
- .queryExecution.toRdd.count())
+ sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat 1>&2' AS (a,b,c) FROM script_trans").count())
}
test("test script transform data type") {