commit ee444fe4b8c9f382524e1fa346c67ba6da8104d8
tree   c231aae992e2de9bf96e75364d940657c5aa84cb
parent 2782818287a71925523c1320291db6cb25221e9f
Author:    Dilip Biswal <dbiswal@us.ibm.com>    2015-12-18 09:54:30 -0800
Committer: Yin Huai <yhuai@databricks.com>      2015-12-18 09:54:30 -0800
[SPARK-11619][SQL] cannot use UDTF in DataFrame.selectExpr
Description of the problem, from cloud-fan:

Actually this line: https://github.com/apache/spark/blob/branch-1.5/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala#L689

When we use `selectExpr`, we pass an `UnresolvedFunction` to `DataFrame.select` and fall into the last case. One workaround is to special-case UDTFs, as we did for `explode` (and `json_tuple` in 1.6), wrapping them with `MultiAlias`. Another workaround is to use `expr`, for example `df.select(expr("explode(a)").as(Nil))`; I think `selectExpr` is no longer needed now that we have the `expr` function.

Author: Dilip Biswal <dbiswal@us.ibm.com>

Closes #9981 from dilipbiswal/spark-11619.
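A minimal sketch of the fixed behavior, assuming a Spark 1.6-era SQLContext named `sqlContext` is in scope (the DataFrame shape mirrors the new DataFrameSuite test below; this snippet is illustrative, not part of the patch):

    import org.apache.spark.sql.functions.expr
    import sqlContext.implicits._

    val df = Seq((Map("1" -> 1), 1)).toDF("a", "b")

    // Works after this patch: the parsed UnresolvedFunction is wrapped in an
    // UnresolvedAlias, so the analyzer can expand the generator into columns.
    df.selectExpr("explode(a)").collect()            // expected: Row("1", 1)

    // The `expr` workaround from the description, usable on unpatched builds:
    df.select(expr("explode(a)").as(Nil)).collect()  // same result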
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala   | 12
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala |  6
 sql/core/src/main/scala/org/apache/spark/sql/Column.scala                           | 12
 sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala                          |  2
 sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala                   |  7
 sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala               |  4
 sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala                      |  2
 7 files changed, 31 insertions(+), 14 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 64dd83a915..c396546b4c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -149,12 +149,12 @@ class Analyzer(
exprs.zipWithIndex.map {
case (expr, i) =>
expr transform {
- case u @ UnresolvedAlias(child) => child match {
+ case u @ UnresolvedAlias(child, optionalAliasName) => child match {
case ne: NamedExpression => ne
case e if !e.resolved => u
case g: Generator => MultiAlias(g, Nil)
case c @ Cast(ne: NamedExpression, _) => Alias(c, ne.name)()
- case other => Alias(other, s"_c$i")()
+ case other => Alias(other, optionalAliasName.getOrElse(s"_c$i"))()
}
}
}.asInstanceOf[Seq[NamedExpression]]
@@ -287,7 +287,7 @@ class Analyzer(
}
}
val newGroupByExprs = groupByExprs.map {
- case UnresolvedAlias(e) => e
+ case UnresolvedAlias(e, _) => e
case e => e
}
Aggregate(newGroupByExprs, groupByExprs ++ pivotAggregates, child)
@@ -352,19 +352,19 @@ class Analyzer(
Project(
projectList.flatMap {
case s: Star => s.expand(child, resolver)
- case UnresolvedAlias(f @ UnresolvedFunction(_, args, _)) if containsStar(args) =>
+ case UnresolvedAlias(f @ UnresolvedFunction(_, args, _), _) if containsStar(args) =>
val newChildren = expandStarExpressions(args, child)
UnresolvedAlias(child = f.copy(children = newChildren)) :: Nil
case Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args) =>
val newChildren = expandStarExpressions(args, child)
Alias(child = f.copy(children = newChildren), name)() :: Nil
- case UnresolvedAlias(c @ CreateArray(args)) if containsStar(args) =>
+ case UnresolvedAlias(c @ CreateArray(args), _) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
}
UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
- case UnresolvedAlias(c @ CreateStruct(args)) if containsStar(args) =>
+ case UnresolvedAlias(c @ CreateStruct(args), _) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
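The net effect of threading the optional alias through the hunk above can be sketched as a standalone helper (a simplification written for illustration; the Cast case is omitted and the names are ours, not the patch's):

    import org.apache.spark.sql.catalyst.analysis.MultiAlias
    import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, Generator, NamedExpression}

    // Prefer a caller-supplied alias; otherwise fall back to the positional
    // default _c<i>. Generators still become MultiAlias so the analyzer can
    // name their multiple output columns once the child is resolved.
    def aliasFor(child: Expression, i: Int, aliasName: Option[String]): NamedExpression =
      child match {
        case ne: NamedExpression => ne
        case g: Generator        => MultiAlias(g, Nil)
        case other               => Alias(other, aliasName.getOrElse(s"_c$i"))()
      }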
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 4f89b462a6..64cad6ee78 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -284,8 +284,12 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression)
/**
* Holds the expression that has yet to be aliased.
+ *
+ * @param child The computation that needs to be resolved during analysis.
+ * @param aliasName The name, if specified, to be associated with the result of computing [[child]].
+ *
*/
-case class UnresolvedAlias(child: Expression)
+case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None)
extends UnaryExpression with NamedExpression with Unevaluable {
override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
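Because the new field is defaulted, one-argument construction keeps compiling, but every pattern match must now bind the extra field, which is exactly why the Analyzer cases above gained a trailing `_`. A small hypothetical illustration:

    import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute}

    val ua = UnresolvedAlias(UnresolvedAttribute("x"))  // aliasName defaults to None

    val label = ua match {
      case UnresolvedAlias(_, Some(name)) => name       // caller-supplied alias
      case UnresolvedAlias(_, None)       => "_c0"      // analyzer picks a default
    }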
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 297ef2299c..5026c0d6d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -17,20 +17,19 @@
package org.apache.spark.sql
-import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
-
import scala.language.implicitConversions
-import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
-import org.apache.spark.sql.functions.lit
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.catalyst.SqlParser._
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.DataTypeParser
+import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
+import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types._
-
private[sql] object Column {
def apply(colName: String): Column = new Column(colName)
@@ -130,8 +129,11 @@ class Column(protected[sql] val expr: Expression) extends Logging {
// Leave an unaliased generator with an empty list of names since the analyzer will generate
// the correct defaults after the nested expression's type has been resolved.
case explode: Explode => MultiAlias(explode, Nil)
+
case jt: JsonTuple => MultiAlias(jt, Nil)
+ case func: UnresolvedFunction => UnresolvedAlias(func, Some(func.prettyString))
+
case expr: Expression => Alias(expr, expr.prettyString)()
}
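This new `named` case is the one `selectExpr` exercises: each input string parses to an UnresolvedFunction, and instead of eagerly aliasing it (the old last case), `named` defers resolution while recording the function's prettyString as the preferred alias. Sketched usage, reusing the `df` from the first example (output names are indicative only):

    df.selectExpr("explode(a)")  // generator: analyzer expands it via MultiAlias
    df.selectExpr("abs(b)")      // scalar: resulting column aliased like "abs(b)"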
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 79b4244ac0..d201d65238 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -450,7 +450,7 @@ class Dataset[T] private[sql](
*/
@scala.annotation.varargs
def groupBy(cols: Column*): GroupedDataset[Row, T] = {
- val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias)
+ val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias(_))
val withKey = Project(withKeyColumns, logicalPlan)
val executed = sqlContext.executePlan(withKey)
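The change from `map(UnresolvedAlias)` to `map(UnresolvedAlias(_))` here (and in HiveQl.scala below) is forced by the new parameter: once a case class gains a second, defaulted field, its companion no longer extends Function1, so the bare companion cannot be passed where a one-argument function is expected. A self-contained illustration of that Scala rule, using a stand-in class of ours rather than Spark code:

    case class Wrap(e: String, name: Option[String] = None)

    val exprs = Seq("a", "b")
    // exprs.map(Wrap)          // no longer compiles: apply takes two parameters
    val ok = exprs.map(Wrap(_)) // compiles; name falls back to its default, None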
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 0644bdaaa3..4c3e12af72 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -176,6 +176,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
testData.select("key").collect().toSeq)
}
+ test("selectExpr with udtf") {
+ val df = Seq((Map("1" -> 1), 1)).toDF("a", "b")
+ checkAnswer(
+ df.selectExpr("explode(a)"),
+ Row("1", 1) :: Nil)
+ }
+
test("filterExpr") {
val res = testData.collect().filter(_.getInt(0) > 90).toSeq
checkAnswer(testData.filter("key > 90"), res)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
index 1f384edf32..1391c9d57f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -73,6 +73,10 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
checkAnswer(
df.select($"key", functions.json_tuple($"jstring", "f1", "f2", "f3", "f4", "f5")),
expected)
+
+ checkAnswer(
+ df.selectExpr("key", "json_tuple(jstring, 'f1', 'f2', 'f3', 'f4', 'f5')"),
+ expected)
}
test("json_tuple filter and group") {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index da41b659e3..0e89928cb6 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -1107,7 +1107,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
// (if there is a group by) or a script transformation.
val withProject: LogicalPlan = transformation.getOrElse {
val selectExpressions =
- select.getChildren.asScala.flatMap(selExprNodeToExpr).map(UnresolvedAlias)
+ select.getChildren.asScala.flatMap(selExprNodeToExpr).map(UnresolvedAlias(_))
Seq(
groupByClause.map(e => e match {
case Token("TOK_GROUPBY", children) =>