From ee444fe4b8c9f382524e1fa346c67ba6da8104d8 Mon Sep 17 00:00:00 2001
From: Dilip Biswal
Date: Fri, 18 Dec 2015 09:54:30 -0800
Subject: [SPARK-11619][SQL] cannot use UDTF in DataFrame.selectExpr

Description of the problem, from cloud-fan:

The culprit is this line:
https://github.com/apache/spark/blob/branch-1.5/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala#L689

When we use `selectExpr`, we pass an `UnresolvedFunction` into `DataFrame.select`, and it falls into the last case. One fix is to special-case UDTFs the way we already do for `explode` (and `json_tuple` in 1.6) and wrap them with `MultiAlias`. Another workaround is to use `expr`, for example `df.select(expr("explode(a)").as(Nil))`; arguably `selectExpr` is no longer needed now that we have the `expr` function.

Author: Dilip Biswal

Closes #9981 from dilipbiswal/spark-11619.
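For context, a minimal sketch of the behavior this patch addresses, assuming a
Spark 1.6-era SQLContext with its implicits in scope (the DataFrame below
mirrors the new test case and is illustrative, not part of the patch):

    import org.apache.spark.sql.functions.expr

    val df = Seq((Map("1" -> 1), 1)).toDF("a", "b")

    // Before this patch, selectExpr left the unresolved generator under a
    // plain Alias, so analysis of the UDTF failed:
    //   df.selectExpr("explode(a)")
    // The workaround from the description forces a MultiAlias via expr:
    df.select(expr("explode(a)").as(Nil))

    // With this patch, the direct form resolves and returns Row("1", 1):
    df.selectExpr("explode(a)")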
---
 .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala    | 12 ++++++------
 .../org/apache/spark/sql/catalyst/analysis/unresolved.scala  |  6 +++++-
 sql/core/src/main/scala/org/apache/spark/sql/Column.scala    | 12 +++++++-----
 sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala   |  2 +-
 .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala |  7 +++++++
 .../test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala |  4 ++++
 .../src/main/scala/org/apache/spark/sql/hive/HiveQl.scala    |  2 +-
 7 files changed, 31 insertions(+), 14 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 64dd83a915..c396546b4c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -149,12 +149,12 @@ class Analyzer(
       exprs.zipWithIndex.map {
         case (expr, i) =>
           expr transform {
-            case u @ UnresolvedAlias(child) => child match {
+            case u @ UnresolvedAlias(child, optionalAliasName) => child match {
               case ne: NamedExpression => ne
               case e if !e.resolved => u
               case g: Generator => MultiAlias(g, Nil)
               case c @ Cast(ne: NamedExpression, _) => Alias(c, ne.name)()
-              case other => Alias(other, s"_c$i")()
+              case other => Alias(other, optionalAliasName.getOrElse(s"_c$i"))()
             }
           }
       }.asInstanceOf[Seq[NamedExpression]]
@@ -287,7 +287,7 @@ class Analyzer(
           }
         }
         val newGroupByExprs = groupByExprs.map {
-          case UnresolvedAlias(e) => e
+          case UnresolvedAlias(e, _) => e
           case e => e
         }
         Aggregate(newGroupByExprs, groupByExprs ++ pivotAggregates, child)
@@ -352,19 +352,19 @@ class Analyzer(
         Project(
           projectList.flatMap {
             case s: Star => s.expand(child, resolver)
-            case UnresolvedAlias(f @ UnresolvedFunction(_, args, _)) if containsStar(args) =>
+            case UnresolvedAlias(f @ UnresolvedFunction(_, args, _), _) if containsStar(args) =>
               val newChildren = expandStarExpressions(args, child)
               UnresolvedAlias(child = f.copy(children = newChildren)) :: Nil
             case Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args) =>
               val newChildren = expandStarExpressions(args, child)
               Alias(child = f.copy(children = newChildren), name)() :: Nil
-            case UnresolvedAlias(c @ CreateArray(args)) if containsStar(args) =>
+            case UnresolvedAlias(c @ CreateArray(args), _) if containsStar(args) =>
               val expandedArgs = args.flatMap {
                 case s: Star => s.expand(child, resolver)
                 case o => o :: Nil
               }
               UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
-            case UnresolvedAlias(c @ CreateStruct(args)) if containsStar(args) =>
+            case UnresolvedAlias(c @ CreateStruct(args), _) if containsStar(args) =>
               val expandedArgs = args.flatMap {
                 case s: Star => s.expand(child, resolver)
                 case o => o :: Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 4f89b462a6..64cad6ee78 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -284,8 +284,12 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression)
 
 /**
  * Holds the expression that has yet to be aliased.
+ *
+ * @param child The computation that needs to be resolved during analysis.
+ * @param aliasName The name, if specified, to be associated with the result of computing [[child]].
+ *
  */
-case class UnresolvedAlias(child: Expression)
+case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None)
   extends UnaryExpression with NamedExpression with Unevaluable {
 
   override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
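A note on the new aliasName parameter: Column.named (in the next hunk) now wraps
an UnresolvedFunction in UnresolvedAlias(func, Some(func.prettyString)), and the
analyzer's fallback case above prefers that suggested name over the positional
_c<i> default. A side effect, visible in the Dataset.scala and HiveQl.scala
hunks below, is that the two-parameter case class companion no longer
eta-expands to a one-argument function, so .map(UnresolvedAlias) must become
.map(UnresolvedAlias(_)). A hypothetical illustration of the naming effect (the
exact output name comes from prettyString and is assumed, not verified):

    val df = Seq((Map("1" -> 1), 1)).toDF("a", "b")
    // The suggested alias should surface as the column name, presumably
    // something like "size(a)" rather than the old "_c0" fallback.
    df.selectExpr("size(a)").columns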
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 297ef2299c..5026c0d6d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -17,20 +17,19 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
-
 import scala.language.implicitConversions
 
-import org.apache.spark.annotation.Experimental
 import org.apache.spark.Logging
-import org.apache.spark.sql.functions.lit
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.catalyst.SqlParser._
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.DataTypeParser
+import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
+import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.types._
 
-
 private[sql] object Column {
 
   def apply(colName: String): Column = new Column(colName)
@@ -130,8 +129,11 @@ class Column(protected[sql] val expr: Expression) extends Logging {
     // Leave an unaliased generator with an empty list of names since the analyzer will generate
     // the correct defaults after the nested expression's type has been resolved.
     case explode: Explode => MultiAlias(explode, Nil)
+    case jt: JsonTuple => MultiAlias(jt, Nil)
 
+    case func: UnresolvedFunction => UnresolvedAlias(func, Some(func.prettyString))
+
     case expr: Expression => Alias(expr, expr.prettyString)()
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 79b4244ac0..d201d65238 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -450,7 +450,7 @@ class Dataset[T] private[sql](
    */
   @scala.annotation.varargs
   def groupBy(cols: Column*): GroupedDataset[Row, T] = {
-    val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias)
+    val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias(_))
     val withKey = Project(withKeyColumns, logicalPlan)
     val executed = sqlContext.executePlan(withKey)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 0644bdaaa3..4c3e12af72 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -176,6 +176,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       testData.select("key").collect().toSeq)
   }
 
+  test("selectExpr with udtf") {
+    val df = Seq((Map("1" -> 1), 1)).toDF("a", "b")
+    checkAnswer(
+      df.selectExpr("explode(a)"),
+      Row("1", 1) :: Nil)
+  }
+
   test("filterExpr") {
     val res = testData.collect().filter(_.getInt(0) > 90).toSeq
     checkAnswer(testData.filter("key > 90"), res)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
index 1f384edf32..1391c9d57f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -73,6 +73,10 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
     checkAnswer(
       df.select($"key", functions.json_tuple($"jstring", "f1", "f2", "f3", "f4", "f5")),
       expected)
+
+    checkAnswer(
+      df.selectExpr("key", "json_tuple(jstring, 'f1', 'f2', 'f3', 'f4', 'f5')"),
+      expected)
   }
 
   test("json_tuple filter and group") {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index da41b659e3..0e89928cb6 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -1107,7 +1107,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
     // (if there is a group by) or a script transformation.
     val withProject: LogicalPlan = transformation.getOrElse {
       val selectExpressions =
-        select.getChildren.asScala.flatMap(selExprNodeToExpr).map(UnresolvedAlias)
+        select.getChildren.asScala.flatMap(selExprNodeToExpr).map(UnresolvedAlias(_))
       Seq(
         groupByClause.map(e => e match {
           case Token("TOK_GROUPBY", children) =>
-- 
cgit v1.2.3
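As exercised by the JsonFunctionsSuite addition above, the same fix covers
json_tuple through selectExpr. A usage sketch, assuming a DataFrame df with
columns key and jstring as in that suite:

    // Both spellings should now produce the same rows:
    df.select($"key", functions.json_tuple($"jstring", "f1", "f2", "f3", "f4", "f5"))
    df.selectExpr("key", "json_tuple(jstring, 'f1', 'f2', 'f3', 'f4', 'f5')")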