author     Cheng Lian <lian@databricks.com>  2016-02-21 22:53:15 +0800
committer  Cheng Lian <lian@databricks.com>  2016-02-21 22:53:15 +0800
commit     d9efe63ecdc60a9955f1924de0e8a00bcb6a559d (patch)
tree       218879fda9e2285db67a72ed9fc42ec3ab61afa5
parent     d806ed34365aa27895547297fff4cc48ecbeacdf (diff)
[SPARK-12799] Simplify various string output for expressions
This PR introduces several major changes:

1. Replacing `Expression.prettyString` with `Expression.sql`

   The `prettyString` method is mostly an internal, developer-facing facility for debugging purposes, and shouldn't be exposed to users.

2. Using SQL-like representation as column names for selected fields that are not named expressions (back-ticks and double quotes should be removed)

   Before, we were using `prettyString` as column names when possible, and sometimes the resulting column names can be weird. Here are several examples:

   Expression         | `prettyString` | `sql`     | Note
   ------------------ | -------------- | --------- | ----------------
   `a && b`           | `a && b`       | `a AND b` |
   `a.getField("f")`  | `a[f]`         | `a.f`     | `a` is a struct

3. Adding trait `NonSQLExpression` extending from `Expression` for expressions that don't have a SQL representation (e.g. Scala UDF/UDAF and Java/Scala object expressions used for encoders)

   `NonSQLExpression.sql` may return an arbitrary user-facing string representation of the expression.

Author: Cheng Lian <lian@databricks.com>

Closes #10757 from liancheng/spark-12799.simplify-expression-string-methods.
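To make the naming change concrete, here is a minimal Catalyst sketch (illustrative only, not part of the patch); the expected output assumes the `UnresolvedAttribute.sql` and `BinaryOperator.sqlOperator` changes in the diff below:

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{And, GreaterThanOrEqual, LessThanOrEqual, Literal}

// Build the predicate ((age >= 2) AND (age <= 4)) directly against Catalyst.
val age = UnresolvedAttribute("age")
val between = And(
  GreaterThanOrEqual(age, Literal(2)),
  LessThanOrEqual(age, Literal(4)))

// Old behavior: prettyString rendered the operator symbol, e.g. "((age >= 2) && (age <= 4))".
// New behavior: BinaryOperator.sql uses sqlOperator and attributes are back-tick quoted.
println(between.sql)   // expected: ((`age` >= 2) AND (`age` <= 4))
```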
-rw-r--r--  R/pkg/inst/tests/testthat/test_sparkSQL.R | 4
-rw-r--r--  python/pyspark/sql/column.py | 32
-rw-r--r--  python/pyspark/sql/context.py | 10
-rw-r--r--  python/pyspark/sql/functions.py | 30
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 26
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala | 6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala | 17
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala | 71
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala | 5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala | 8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala | 5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala | 22
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala | 39
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala | 6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala | 20
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala | 25
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala | 8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala | 5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala | 6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala | 37
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala | 4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 26
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala | 15
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala | 4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 17
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala | 9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala | 4
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala | 22
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala | 43
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala | 15
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala | 64
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala | 4
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala | 22
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala | 5
49 files changed, 402 insertions, 278 deletions
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 7b5713720d..cc118108f6 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1047,13 +1047,13 @@ test_that("column functions", {
schema = c("a", "b", "c"))
result <- collect(select(df, struct("a", "c")))
expected <- data.frame(row.names = 1:2)
- expected$"struct(a,c)" <- list(listToStruct(list(a = 1L, c = 3L)),
+ expected$"struct(a, c)" <- list(listToStruct(list(a = 1L, c = 3L)),
listToStruct(list(a = 4L, c = 6L)))
expect_equal(result, expected)
result <- collect(select(df, struct(df$a, df$b)))
expected <- data.frame(row.names = 1:2)
- expected$"struct(a,b)" <- list(listToStruct(list(a = 1L, b = 2L)),
+ expected$"struct(a, b)" <- list(listToStruct(list(a = 1L, b = 2L)),
listToStruct(list(a = 4L, b = 5L)))
expect_equal(result, expected)
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 320451c52c..3866a49c0b 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -219,17 +219,17 @@ class Column(object):
>>> from pyspark.sql import Row
>>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF()
>>> df.select(df.r.getField("b")).show()
- +----+
- |r[b]|
- +----+
- | b|
- +----+
+ +---+
+ |r.b|
+ +---+
+ | b|
+ +---+
>>> df.select(df.r.a).show()
- +----+
- |r[a]|
- +----+
- | 1|
- +----+
+ +---+
+ |r.a|
+ +---+
+ | 1|
+ +---+
"""
return self[name]
@@ -346,12 +346,12 @@ class Column(object):
expression is between the given columns.
>>> df.select(df.name, df.age.between(2, 4)).show()
- +-----+--------------------------+
- | name|((age >= 2) && (age <= 4))|
- +-----+--------------------------+
- |Alice| true|
- | Bob| false|
- +-----+--------------------------+
+ +-----+---------------------------+
+ | name|((age >= 2) AND (age <= 4))|
+ +-----+---------------------------+
+ |Alice| true|
+ | Bob| false|
+ +-----+---------------------------+
"""
return (self >= lowerBound) & (self <= upperBound)
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 7f6fb410ab..89bf1443a6 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -92,8 +92,8 @@ class SQLContext(object):
>>> df.registerTempTable("allTypes")
>>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
... 'from allTypes where b and i > 0').collect()
- [Row(_c0=2, _c1=2.0, _c2=False, _c3=2, _c4=0, \
- time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
+ [Row((i + CAST(1 AS BIGINT))=2, (d + CAST(1 AS DOUBLE))=2.0, (NOT b)=False, list[1]=2, \
+ dict[s]=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
>>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect()
[(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
"""
@@ -210,17 +210,17 @@ class SQLContext(object):
>>> sqlContext.registerFunction("stringLengthString", lambda x: len(x))
>>> sqlContext.sql("SELECT stringLengthString('test')").collect()
- [Row(_c0=u'4')]
+ [Row(stringLengthString(test)=u'4')]
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
- [Row(_c0=4)]
+ [Row(stringLengthInt(test)=4)]
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
- [Row(_c0=4)]
+ [Row(stringLengthInt(test)=4)]
"""
udf = UserDefinedFunction(f, returnType, name)
self._ssql_ctx.udf().registerPython(name, udf._judf)
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 5fc1cc2cae..fdae05d98c 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -223,22 +223,22 @@ def coalesce(*cols):
+----+----+
>>> cDf.select(coalesce(cDf["a"], cDf["b"])).show()
- +-------------+
- |coalesce(a,b)|
- +-------------+
- | null|
- | 1|
- | 2|
- +-------------+
+ +--------------+
+ |coalesce(a, b)|
+ +--------------+
+ | null|
+ | 1|
+ | 2|
+ +--------------+
>>> cDf.select('*', coalesce(cDf["a"], lit(0.0))).show()
- +----+----+---------------+
- | a| b|coalesce(a,0.0)|
- +----+----+---------------+
- |null|null| 0.0|
- | 1|null| 1.0|
- |null| 2| 0.0|
- +----+----+---------------+
+ +----+----+----------------+
+ | a| b|coalesce(a, 0.0)|
+ +----+----+----------------+
+ |null|null| 0.0|
+ | 1|null| 1.0|
+ |null| 2| 0.0|
+ +----+----+----------------+
"""
sc = SparkContext._active_spark_context
jc = sc._jvm.functions.coalesce(_to_seq(sc, cols, _to_java_column))
@@ -1528,7 +1528,7 @@ def array_contains(col, value):
>>> df = sqlContext.createDataFrame([(["a", "b", "c"],), ([],)], ['data'])
>>> df.select(array_contains(df.data, "a")).collect()
- [Row(array_contains(data,a)=True), Row(array_contains(data,a)=False)]
+ [Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index ba306f8b32..e153f4dd2b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
+import org.apache.spark.sql.catalyst.util.usePrettyExpression
import org.apache.spark.sql.types._
/**
@@ -165,7 +166,8 @@ class Analyzer(
case e if !e.resolved => u
case g: Generator => MultiAlias(g, Nil)
case c @ Cast(ne: NamedExpression, _) => Alias(c, ne.name)()
- case other => Alias(other, optionalAliasName.getOrElse(s"_c$i"))()
+ case e: ExtractValue => Alias(e, usePrettyExpression(e).sql)()
+ case e => Alias(e, optionalAliasName.getOrElse(usePrettyExpression(e).sql))()
}
}
}.asInstanceOf[Seq[NamedExpression]]
@@ -328,7 +330,7 @@ class Analyzer(
throw new AnalysisException(
s"Aggregate expression required for pivot, found '$aggregate'")
}
- val name = if (singleAgg) value.toString else value + "_" + aggregate.prettyString
+ val name = if (singleAgg) value.toString else value + "_" + aggregate.sql
Alias(filteredAggregate, name)()
}
}
@@ -1456,7 +1458,7 @@ object CleanupAliases extends Rule[LogicalPlan] {
*/
object ResolveUpCast extends Rule[LogicalPlan] {
private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
- throw new AnalysisException(s"Cannot up cast `${from.prettyString}` from " +
+ throw new AnalysisException(s"Cannot up cast ${from.sql} from " +
s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" +
"The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
"You can either add an explicit cast to the input data or choose a higher precision " +
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index fe053b9a0b..1e430c1fbb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -57,13 +57,13 @@ trait CheckAnalysis {
operator transformExpressionsUp {
case a: Attribute if !a.resolved =>
val from = operator.inputSet.map(_.name).mkString(", ")
- a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns: [$from]")
+ a.failAnalysis(s"cannot resolve '${a.sql}' given input columns: [$from]")
case e: Expression if e.checkInputDataTypes().isFailure =>
e.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckFailure(message) =>
e.failAnalysis(
- s"cannot resolve '${e.prettyString}' due to data type mismatch: $message")
+ s"cannot resolve '${e.sql}' due to data type mismatch: $message")
}
case c: Cast if !c.resolved =>
@@ -106,12 +106,12 @@ trait CheckAnalysis {
operator match {
case f: Filter if f.condition.dataType != BooleanType =>
failAnalysis(
- s"filter expression '${f.condition.prettyString}' " +
+ s"filter expression '${f.condition.sql}' " +
s"of type ${f.condition.dataType.simpleString} is not a boolean.")
case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType =>
failAnalysis(
- s"join condition '${condition.prettyString}' " +
+ s"join condition '${condition.sql}' " +
s"of type ${condition.dataType.simpleString} is not a boolean.")
case j @ Join(_, _, _, Some(condition)) =>
@@ -119,10 +119,10 @@ trait CheckAnalysis {
case p: Predicate =>
p.asInstanceOf[Expression].children.foreach(checkValidJoinConditionExprs)
case e if e.dataType.isInstanceOf[BinaryType] =>
- failAnalysis(s"binary type expression ${e.prettyString} cannot be used " +
+ failAnalysis(s"binary type expression ${e.sql} cannot be used " +
"in join conditions")
case e if e.dataType.isInstanceOf[MapType] =>
- failAnalysis(s"map type expression ${e.prettyString} cannot be used " +
+ failAnalysis(s"map type expression ${e.sql} cannot be used " +
"in join conditions")
case _ => // OK
}
@@ -144,13 +144,13 @@ trait CheckAnalysis {
if (!child.deterministic) {
failAnalysis(
- s"nondeterministic expression ${expr.prettyString} should not " +
+ s"nondeterministic expression ${expr.sql} should not " +
s"appear in the arguments of an aggregate function.")
}
}
case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) =>
failAnalysis(
- s"expression '${e.prettyString}' is neither present in the group by, " +
+ s"expression '${e.sql}' is neither present in the group by, " +
s"nor is it an aggregate function. " +
"Add to group by or wrap in first() (or first_value) if you don't care " +
"which value you get.")
@@ -163,7 +163,7 @@ trait CheckAnalysis {
// Check if the data type of expr is orderable.
if (!RowOrdering.isOrderable(expr.dataType)) {
failAnalysis(
- s"expression ${expr.prettyString} cannot be used as a grouping expression " +
+ s"expression ${expr.sql} cannot be used as a grouping expression " +
s"because its data type ${expr.dataType.simpleString} is not a orderable " +
s"data type.")
}
@@ -172,7 +172,7 @@ trait CheckAnalysis {
// This is just a sanity check, our analysis rule PullOutNondeterministic should
// already pull out those nondeterministic expressions and evaluate them in
// a Project node.
- failAnalysis(s"nondeterministic expression ${expr.prettyString} should not " +
+ failAnalysis(s"nondeterministic expression ${expr.sql} should not " +
s"appear in grouping expression.")
}
}
@@ -217,7 +217,7 @@ trait CheckAnalysis {
case p @ Project(exprs, _) if containsMultipleGenerators(exprs) =>
failAnalysis(
s"""Only a single table generating function is allowed in a SELECT clause, found:
- | ${exprs.map(_.prettyString).mkString(",")}""".stripMargin)
+ | ${exprs.map(_.sql).mkString(",")}""".stripMargin)
case j: Join if !j.duplicateResolved =>
val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet)
@@ -248,9 +248,9 @@ trait CheckAnalysis {
failAnalysis(
s"""nondeterministic expressions are only allowed in
|Project, Filter, Aggregate or Window, found:
- | ${o.expressions.map(_.prettyString).mkString(",")}
+ | ${o.expressions.map(_.sql).mkString(",")}
|in operator ${operator.simpleString}
- """.stripMargin)
+ """.stripMargin)
case _ => // Analysis successful!
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
index 5dfce89bd6..b49885d469 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
@@ -130,7 +130,7 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
new AttributeReference("gid", IntegerType, false)(isGenerated = true)
val groupByMap = a.groupingExpressions.collect {
case ne: NamedExpression => ne -> ne.toAttribute
- case e => e -> new AttributeReference(e.prettyString, e.dataType, e.nullable)()
+ case e => e -> new AttributeReference(e.sql, e.dataType, e.nullable)()
}
val groupByAttrs = groupByMap.map(_._2)
@@ -184,7 +184,7 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
val regularAggOperatorMap = regularAggExprs.map { e =>
// Perform the actual aggregation in the initial aggregate.
val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup)
- val operator = Alias(e.copy(aggregateFunction = af), e.prettyString)()
+ val operator = Alias(e.copy(aggregateFunction = af), e.sql)()
// Select the result of the first aggregate in the last aggregate.
val result = AggregateExpression(
@@ -269,5 +269,5 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
// NamedExpression. This is done to prevent collisions between distinct and regular aggregate
// children, in this case attribute reuse causes the input of the regular aggregate to bound to
// the (nulled out) input of the distinct aggregate.
- e -> new AttributeReference(e.prettyString, e.dataType, true)()
+ e -> new AttributeReference(e.sql, e.dataType, true)()
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 79eebbf9b1..01afa01ae9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}
import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.catalyst.util.quoteIdentifier
import org.apache.spark.sql.types.{DataType, StructType}
/**
@@ -67,6 +68,8 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un
override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName)
override def toString: String = s"'$name"
+
+ override def sql: String = quoteIdentifier(name)
}
object UnresolvedAttribute {
@@ -141,11 +144,8 @@ case class UnresolvedFunction(
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override lazy val resolved = false
- override def prettyString: String = {
- s"${name}(${children.map(_.prettyString).mkString(",")})"
- }
-
- override def toString: String = s"'$name(${children.mkString(",")})"
+ override def prettyName: String = name
+ override def toString: String = s"'$name(${children.mkString(", ")})"
}
/**
@@ -208,10 +208,9 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu
Alias(extract, f.name)()
}
- case _ => {
+ case _ =>
throw new AnalysisException("Can only star expand struct data types. Attribute: `" +
target.get + "`")
- }
}
} else {
val from = input.inputSet.map(_.name).mkString(", ")
@@ -228,6 +227,7 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu
* For example the SQL expression "stack(2, key, value, key, value) as (a, b)" could be represented
* as follows:
* MultiAlias(stack_function, Seq(a, b))
+ *
* @param child the computation being performed
* @param names the names to be associated with each output of computing [[child]].
@@ -284,13 +284,14 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression)
override lazy val resolved = false
override def toString: String = s"$child[$extraction]"
+ override def sql: String = s"${child.sql}[${extraction.sql}]"
}
/**
* Holds the expression that has yet to be aliased.
*
* @param child The computation that is needs to be resolved during analysis.
- * @param aliasName The name if specified to be asoosicated with the result of computing [[child]]
+ * @param aliasName The name if specified to be associated with the result of computing [[child]]
*
*/
case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
index 04650d85de..c7be8e886c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
@@ -45,7 +45,7 @@ trait ExpectsInputTypes extends Expression {
val mismatches = children.zip(inputTypes).zipWithIndex.collect {
case ((child, expected), idx) if !expected.acceptsType(child.dataType) =>
s"argument ${idx + 1} requires ${expected.simpleString} type, " +
- s"however, '${child.prettyString}' is of ${child.dataType.simpleString} type."
+ s"however, '${child.sql}' is of ${child.dataType.simpleString} type."
}
if (mismatches.isEmpty) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index c73b2f8f2a..119496c7ee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -18,10 +18,10 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{Analyzer, TypeCheckResult, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.trees.TreeNode
-import org.apache.spark.sql.catalyst.util.sequenceOption
+import org.apache.spark.sql.catalyst.util.toCommentSafeString
import org.apache.spark.sql.types._
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -96,7 +96,7 @@ abstract class Expression extends TreeNode[Expression] {
ctx.subExprEliminationExprs.get(this).map { subExprState =>
// This expression is repeated meaning the code to evaluated has already been added
// as a function and called in advance. Just use it.
- val code = s"/* ${this.toCommentSafeString} */"
+ val code = s"/* ${toCommentSafeString(this.toString)} */"
ExprCode(code, subExprState.isNull, subExprState.value)
}.getOrElse {
val isNull = ctx.freshName("isNull")
@@ -105,7 +105,7 @@ abstract class Expression extends TreeNode[Expression] {
ve.code = genCode(ctx, ve)
if (ve.code != "") {
// Add `this` in the comment.
- ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim)
+ ve.copy(s"/* ${toCommentSafeString(this.toString)} */\n" + ve.code.trim)
} else {
ve
}
@@ -201,17 +201,6 @@ abstract class Expression extends TreeNode[Expression] {
*/
def prettyName: String = getClass.getSimpleName.toLowerCase
- /**
- * Returns a user-facing string representation of this expression, i.e. does not have developer
- * centric debugging information like the expression id.
- */
- def prettyString: String = {
- transform {
- case a: AttributeReference => PrettyAttribute(a.name, a.dataType)
- case u: UnresolvedAttribute => PrettyAttribute(u.name)
- }.toString
- }
-
private def flatArguments = productIterator.flatMap {
case t: Traversable[_] => t
case single => single :: Nil
@@ -219,24 +208,16 @@ abstract class Expression extends TreeNode[Expression] {
override def simpleString: String = toString
- override def toString: String = prettyName + flatArguments.mkString("(", ",", ")")
+ override def toString: String = prettyName + flatArguments.mkString("(", ", ", ")")
/**
- * Returns the string representation of this expression that is safe to be put in
- * code comments of generated code.
+ * Returns SQL representation of this expression. For expressions extending [[NonSQLExpression]],
+ * this method may return an arbitrary user facing string.
*/
- protected def toCommentSafeString: String = this.toString
- .replace("*/", "\\*\\/")
- .replace("\\u", "\\\\u")
-
- /**
- * Returns SQL representation of this expression. For expressions that don't have a SQL
- * representation (e.g. `ScalaUDF`), this method should throw an `UnsupportedOperationException`.
- */
- @throws[UnsupportedOperationException](cause = "Expression doesn't have a SQL representation")
- def sql: String = throw new UnsupportedOperationException(
- s"Cannot map expression $this to its SQL representation"
- )
+ def sql: String = {
+ val childrenSQL = children.map(_.sql).mkString(", ")
+ s"$prettyName($childrenSQL)"
+ }
}
@@ -255,6 +236,19 @@ trait Unevaluable extends Expression {
/**
+ * Expressions that don't have SQL representation should extend this trait. Examples are
+ * `ScalaUDF`, `ScalaUDAF`, and object expressions like `MapObjects` and `Invoke`.
+ */
+trait NonSQLExpression extends Expression {
+ override def sql: String = {
+ transform {
+ case a: Attribute => new PrettyAttribute(a)
+ }.toString
+ }
+}
+
+
+/**
* An expression that is nondeterministic.
*/
trait Nondeterministic extends Expression {
@@ -373,8 +367,6 @@ abstract class UnaryExpression extends Expression {
"""
}
}
-
- override def sql: String = s"($prettyName(${child.sql}))"
}
@@ -477,8 +469,6 @@ abstract class BinaryExpression extends Expression {
"""
}
}
-
- override def sql: String = s"$prettyName(${left.sql}, ${right.sql})"
}
@@ -499,6 +489,8 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {
def symbol: String
+ def sqlOperator: String = symbol
+
override def toString: String = s"($left $symbol $right)"
override def inputTypes: Seq[AbstractDataType] = Seq(inputType, inputType)
@@ -506,17 +498,17 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {
override def checkInputDataTypes(): TypeCheckResult = {
// First check whether left and right have the same type, then check if the type is acceptable.
if (left.dataType != right.dataType) {
- TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
+ TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " +
s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).")
} else if (!inputType.acceptsType(left.dataType)) {
- TypeCheckResult.TypeCheckFailure(s"'$prettyString' requires ${inputType.simpleString} type," +
+ TypeCheckResult.TypeCheckFailure(s"'$sql' requires ${inputType.simpleString} type," +
s" not ${left.dataType.simpleString}")
} else {
TypeCheckResult.TypeCheckSuccess
}
}
- override def sql: String = s"(${left.sql} $symbol ${right.sql})"
+ override def sql: String = s"(${left.sql} $sqlOperator ${right.sql})"
}
@@ -623,9 +615,4 @@ abstract class TernaryExpression extends Expression {
"""
}
}
-
- override def sql: String = {
- val childrenSQL = children.map(_.sql).mkString(", ")
- s"$prettyName($childrenSQL)"
- }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 681694746b..22184f1ddf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -39,11 +39,11 @@ case class ScalaUDF(
dataType: DataType,
children: Seq[Expression],
inputTypes: Seq[DataType] = Nil)
- extends Expression with ImplicitCastInputTypes {
+ extends Expression with ImplicitCastInputTypes with NonSQLExpression {
override def nullable: Boolean = true
- override def toString: String = s"UDF(${children.mkString(",")})"
+ override def toString: String = s"UDF(${children.mkString(", ")})"
// scalastyle:off line.size.limit
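For expressions that now mix in `NonSQLExpression` (such as `ScalaUDF` above), `sql` falls back to the pretty `toString` form rather than real SQL. A rough sketch of what that looks like (illustrative only; the column name and function are assumptions):

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ScalaUDF}
import org.apache.spark.sql.types.IntegerType

// A UDF over a single INT column; NonSQLExpression.sql replaces attributes with
// PrettyAttribute and returns the toString form instead of a SQL fragment.
val in = AttributeReference("a", IntegerType)()
val udf = ScalaUDF((x: Int) => x + 1, IntegerType, Seq(in))

println(udf.sql)   // expected: UDF(a)  -- no expression ID, no back-ticks
```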
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index f88a57a254..ff3064ac66 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.catalyst.util.sequenceOption
import org.apache.spark.sql.types._
/** The mode of an [[AggregateFunction]]. */
@@ -92,8 +91,6 @@ private[sql] case class AggregateExpression(
AttributeSet(childReferences)
}
- override def prettyString: String = aggregateFunction.prettyString
-
override def toString: String = s"($aggregateFunction,mode=$mode,isDistinct=$isDistinct)"
override def sql: String = aggregateFunction.sql(isDistinct)
@@ -168,7 +165,7 @@ sealed abstract class AggregateFunction extends Expression with ImplicitCastInpu
}
def sql(isDistinct: Boolean): String = {
- val distinct = if (isDistinct) "DISTINCT " else " "
+ val distinct = if (isDistinct) "DISTINCT " else ""
s"$prettyName($distinct${children.map(_.sql).mkString(", ")})"
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 1cacd3f76a..5af234609d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -319,7 +319,9 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
}
}
-case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
+case class MaxOf(left: Expression, right: Expression)
+ extends BinaryArithmetic with NonSQLExpression {
+
// TODO: Remove MaxOf and MinOf, and replace its usage with Greatest and Least.
override def inputType: AbstractDataType = TypeCollection.Ordered
@@ -373,7 +375,9 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "max"
}
-case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
+case class MinOf(left: Expression, right: Expression)
+ extends BinaryArithmetic with NonSQLExpression {
+
// TODO: Remove MaxOf and MinOf, and replace its usage with Greatest and Least.
override def inputType: AbstractDataType = TypeCollection.Ordered
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
index a97bd9edce..4c90b3f7d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
@@ -123,4 +123,6 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp
}
protected override def nullSafeEval(input: Any): Any = not(input)
+
+ override def sql: String = s"~${child.sql}"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
index f58a2daf90..1365ee4b55 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic}
+import org.apache.spark.sql.catalyst.util.toCommentSafeString
/**
* A trait that can be used to provide a fallback mode for expression code generation.
@@ -37,7 +38,7 @@ trait CodegenFallback extends Expression {
val objectTerm = ctx.freshName("obj")
if (nullable) {
s"""
- /* expression: ${this.toCommentSafeString} */
+ /* expression: ${toCommentSafeString(this.toString)} */
Object $objectTerm = ((Expression) references[$idx]).eval($input);
boolean ${ev.isNull} = $objectTerm == null;
${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)};
@@ -48,7 +49,7 @@ trait CodegenFallback extends Expression {
} else {
ev.isNull = "false"
s"""
- /* expression: ${this.toCommentSafeString} */
+ /* expression: ${toCommentSafeString(this.toString)} */
Object $objectTerm = ((Expression) references[$idx]).eval($input);
${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm;
"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 6b24fae9f3..44cdc8d881 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
+import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -93,6 +93,8 @@ object ExtractValue {
}
}
+trait ExtractValue extends Expression
+
/**
* Returns the value of fields in the Struct `child`.
*
@@ -102,13 +104,15 @@ object ExtractValue {
* For example, when get field `yEAr` from `<year: int, month: int>`, we should pass in `yEAr`.
*/
case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None)
- extends UnaryExpression {
+ extends UnaryExpression with ExtractValue {
private[sql] lazy val childSchema = child.dataType.asInstanceOf[StructType]
override def dataType: DataType = childSchema(ordinal).dataType
override def nullable: Boolean = child.nullable || childSchema(ordinal).nullable
override def toString: String = s"$child.${name.getOrElse(childSchema(ordinal).name)}"
+ override def sql: String =
+ child.sql + s".${quoteIdentifier(name.getOrElse(childSchema(ordinal).name))}"
protected override def nullSafeEval(input: Any): Any =
input.asInstanceOf[InternalRow].get(ordinal, childSchema(ordinal).dataType)
@@ -130,12 +134,11 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String]
}
})
}
-
- override def sql: String = child.sql + s".`${childSchema(ordinal).name}`"
}
/**
- * Returns the array of value of fields in the Array of Struct `child`.
+ * For a child whose data type is an array of structs, extracts the `ordinal`-th fields of all array
+ * elements, and returns them as a new array.
*
* No need to do type checking since it is handled by [[ExtractValue]].
*/
@@ -144,10 +147,11 @@ case class GetArrayStructFields(
field: StructField,
ordinal: Int,
numFields: Int,
- containsNull: Boolean) extends UnaryExpression {
+ containsNull: Boolean) extends UnaryExpression with ExtractValue {
override def dataType: DataType = ArrayType(field.dataType, containsNull)
override def toString: String = s"$child.${field.name}"
+ override def sql: String = s"${child.sql}.${quoteIdentifier(field.name)}"
protected override def nullSafeEval(input: Any): Any = {
val array = input.asInstanceOf[ArrayData]
@@ -204,12 +208,13 @@ case class GetArrayStructFields(
* We need to do type checking here as `ordinal` expression maybe unresolved.
*/
case class GetArrayItem(child: Expression, ordinal: Expression)
- extends BinaryExpression with ExpectsInputTypes {
+ extends BinaryExpression with ExpectsInputTypes with ExtractValue {
// We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType)
override def toString: String = s"$child[$ordinal]"
+ override def sql: String = s"${child.sql}[${ordinal.sql}]"
override def left: Expression = child
override def right: Expression = ordinal
@@ -250,7 +255,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
* We need to do type checking here as `key` expression maybe unresolved.
*/
case class GetMapValue(child: Expression, key: Expression)
- extends BinaryExpression with ExpectsInputTypes {
+ extends BinaryExpression with ExpectsInputTypes with ExtractValue {
private def keyType = child.dataType.asInstanceOf[MapType].keyType
@@ -258,6 +263,7 @@ case class GetMapValue(child: Expression, key: Expression)
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType)
override def toString: String = s"$child[$key]"
+ override def sql: String = s"${child.sql}[${key.sql}]"
override def left: Expression = child
override def right: Expression = key
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 1eff2c4dd0..200c6a05df 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -35,7 +35,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
TypeCheckResult.TypeCheckFailure(
s"type of predicate expression in If should be boolean, not ${predicate.dataType}")
} else if (trueValue.dataType != falseValue.dataType) {
- TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
+ TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " +
s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).")
} else {
TypeCheckResult.TypeCheckSuccess
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index ca0892eb42..d7d768babc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -238,37 +238,20 @@ case class Literal protected (value: Any, dataType: DataType)
}
override def sql: String = (value, dataType) match {
- case (_, NullType | _: ArrayType | _: MapType | _: StructType) if value == null =>
- "NULL"
-
- case _ if value == null =>
- s"CAST(NULL AS ${dataType.sql})"
-
+ case (_, NullType | _: ArrayType | _: MapType | _: StructType) if value == null => "NULL"
+ case _ if value == null => s"CAST(NULL AS ${dataType.sql})"
case (v: UTF8String, StringType) =>
// Escapes all backslashes and double quotes.
"\"" + v.toString.replace("\\", "\\\\").replace("\"", "\\\"") + "\""
-
- case (v: Byte, ByteType) =>
- s"CAST($v AS ${ByteType.simpleString.toUpperCase})"
-
- case (v: Short, ShortType) =>
- s"CAST($v AS ${ShortType.simpleString.toUpperCase})"
-
- case (v: Long, LongType) =>
- s"CAST($v AS ${LongType.simpleString.toUpperCase})"
-
- case (v: Float, FloatType) =>
- s"CAST($v AS ${FloatType.simpleString.toUpperCase})"
-
- case (v: Decimal, DecimalType.Fixed(precision, scale)) =>
- s"CAST($v AS ${DecimalType.simpleString.toUpperCase}($precision, $scale))"
-
- case (v: Int, DateType) =>
- s"DATE '${DateTimeUtils.toJavaDate(v)}'"
-
- case (v: Long, TimestampType) =>
- s"TIMESTAMP('${DateTimeUtils.toJavaTimestamp(v)}')"
-
+ case (v: Byte, ByteType) => v + "Y"
+ case (v: Short, ShortType) => v + "S"
+ case (v: Long, LongType) => v + "L"
+ // Float type doesn't have a suffix
+ case (v: Float, FloatType) => s"CAST($v AS ${FloatType.sql})"
+ case (v: Double, DoubleType) => v + "D"
+ case (v: Decimal, t: DecimalType) => s"CAST($v AS ${t.sql})"
+ case (v: Int, DateType) => s"DATE '${DateTimeUtils.toJavaDate(v)}'"
+ case (v: Long, TimestampType) => s"TIMESTAMP('${DateTimeUtils.toJavaTimestamp(v)}')"
case _ => value.toString
}
}
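For reference, a small sketch of what the new `Literal.sql` cases above produce (illustrative values, not part of the patch):

```scala
import org.apache.spark.sql.catalyst.expressions.Literal

// Integral and floating-point literals now use type suffixes instead of CASTs.
Literal(1.toByte).sql   // "1Y"
Literal(1.toShort).sql  // "1S"
Literal(1L).sql         // "1L"
Literal(1.5).sql        // "1.5D"
Literal(1.5f).sql       // "CAST(1.5 AS FLOAT)"  -- float has no SQL suffix
Literal("hello").sql    // "\"hello\""            -- backslashes and double quotes escaped
```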
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 8b9a60f97c..bc2df0fb4a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -42,6 +42,7 @@ abstract class LeafMathExpression(c: Double, name: String)
override def foldable: Boolean = true
override def nullable: Boolean = false
override def toString: String = s"$name()"
+ override def prettyName: String = name
override def eval(input: InternalRow): Any = c
}
@@ -59,6 +60,7 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String)
override def dataType: DataType = DoubleType
override def nullable: Boolean = true
override def toString: String = s"$name($child)"
+ override def prettyName: String = name
protected override def nullSafeEval(input: Any): Any = {
f(input.asInstanceOf[Double])
@@ -70,8 +72,6 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String)
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)")
}
-
- override def sql: String = s"$name(${child.sql})"
}
abstract class UnaryLogExpression(f: Double => Double, name: String)
@@ -113,6 +113,8 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
override def toString: String = s"$name($left, $right)"
+ override def prettyName: String = name
+
override def dataType: DataType = DoubleType
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 207b8a0a88..1af5437647 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -22,6 +22,7 @@ import java.util.UUID
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.util.quoteIdentifier
import org.apache.spark.sql.types._
object NamedExpression {
@@ -183,9 +184,9 @@ case class Alias(child: Expression, name: String)(
override def sql: String = {
val qualifiersString =
- if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".")
+ if (qualifiers.isEmpty) "" else qualifiers.map(quoteIdentifier).mkString("", ".", ".")
val aliasName = if (isGenerated) s"$name#${exprId.id}" else s"$name"
- s"${child.sql} AS $qualifiersString`$aliasName`"
+ s"${child.sql} AS $qualifiersString${quoteIdentifier(aliasName)}"
}
}
@@ -300,9 +301,9 @@ case class AttributeReference(
override def sql: String = {
val qualifiersString =
- if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".")
+ if (qualifiers.isEmpty) "" else qualifiers.map(quoteIdentifier).mkString("", ".", ".")
val attrRefName = if (isGenerated) s"$name#${exprId.id}" else s"$name"
- s"$qualifiersString`$attrRefName`"
+ s"$qualifiersString${quoteIdentifier(attrRefName)}"
}
}
@@ -310,10 +311,19 @@ case class AttributeReference(
* A place holder used when printing expressions without debugging information such as the
* expression id or the unresolved indicator.
*/
-case class PrettyAttribute(name: String, dataType: DataType = NullType)
+case class PrettyAttribute(
+ name: String,
+ dataType: DataType = NullType)
extends Attribute with Unevaluable {
+ def this(attribute: Attribute) = this(attribute.name, attribute match {
+ case a: AttributeReference => a.dataType
+ case a: PrettyAttribute => a.dataType
+ case _ => NullType
+ })
+
override def toString: String = name
+ override def sql: String = toString
override def withNullability(newNullability: Boolean): Attribute =
throw new UnsupportedOperationException
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 667d3513d3..e22026d584 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -83,8 +83,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
"""
}.mkString("\n")
}
-
- override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index fef6825b2d..737346dc79 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -46,7 +46,7 @@ case class StaticInvoke(
dataType: DataType,
functionName: String,
arguments: Seq[Expression] = Nil,
- propagateNull: Boolean = true) extends Expression {
+ propagateNull: Boolean = true) extends Expression with NonSQLExpression {
val objectName = staticObject.getName.stripSuffix("$")
@@ -108,7 +108,7 @@ case class Invoke(
targetObject: Expression,
functionName: String,
dataType: DataType,
- arguments: Seq[Expression] = Nil) extends Expression {
+ arguments: Seq[Expression] = Nil) extends Expression with NonSQLExpression {
override def nullable: Boolean = true
override def children: Seq[Expression] = arguments.+:(targetObject)
@@ -204,7 +204,7 @@ case class NewInstance(
arguments: Seq[Expression],
propagateNull: Boolean,
dataType: DataType,
- outerPointer: Option[Literal]) extends Expression {
+ outerPointer: Option[Literal]) extends Expression with NonSQLExpression {
private val className = cls.getName
override def nullable: Boolean = propagateNull
@@ -268,7 +268,7 @@ case class NewInstance(
*/
case class UnwrapOption(
dataType: DataType,
- child: Expression) extends UnaryExpression with ExpectsInputTypes {
+ child: Expression) extends UnaryExpression with NonSQLExpression with ExpectsInputTypes {
override def nullable: Boolean = true
@@ -298,7 +298,7 @@ case class UnwrapOption(
* @param optType The type of this option.
*/
case class WrapOption(child: Expression, optType: DataType)
- extends UnaryExpression with ExpectsInputTypes {
+ extends UnaryExpression with NonSQLExpression with ExpectsInputTypes {
override def dataType: DataType = ObjectType(classOf[Option[_]])
@@ -328,7 +328,7 @@ case class WrapOption(child: Expression, optType: DataType)
* manually, but will instead be passed into the provided lambda function.
*/
case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends LeafExpression
- with Unevaluable {
+ with Unevaluable with NonSQLExpression {
override def nullable: Boolean = true
@@ -368,7 +368,7 @@ object MapObjects {
case class MapObjects private(
loopVar: LambdaVariable,
lambdaFunction: Expression,
- inputData: Expression) extends Expression {
+ inputData: Expression) extends Expression with NonSQLExpression {
private def itemAccessorMethod(dataType: DataType): String => String = dataType match {
case NullType =>
@@ -483,7 +483,7 @@ case class MapObjects private(
*
* @param children A list of expression to use as content of the external row.
*/
-case class CreateExternalRow(children: Seq[Expression]) extends Expression {
+case class CreateExternalRow(children: Seq[Expression]) extends Expression with NonSQLExpression {
override def dataType: DataType = ObjectType(classOf[Row])
override def nullable: Boolean = false
@@ -516,7 +516,8 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression {
* Serializes an input object using a generic serializer (Kryo or Java).
* @param kryo if true, use Kryo. Otherwise, use Java.
*/
-case class EncodeUsingSerializer(child: Expression, kryo: Boolean) extends UnaryExpression {
+case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
+ extends UnaryExpression with NonSQLExpression {
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
@@ -558,7 +559,7 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) extends Unary
* @param kryo if true, use Kryo. Otherwise, use Java.
*/
case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean)
- extends UnaryExpression {
+ extends UnaryExpression with NonSQLExpression {
override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = {
// Code to initialize the serializer.
@@ -596,7 +597,7 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
* Initialize a Java Bean instance by setting its field values via setters.
*/
case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Expression])
- extends Expression {
+ extends Expression with NonSQLExpression {
override def nullable: Boolean = beanInstance.nullable
override def children: Seq[Expression] = beanInstance +: setters.values.toSeq
@@ -638,7 +639,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
* non-null `s`, `s.i` can't be null.
*/
case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
- extends UnaryExpression {
+ extends UnaryExpression with NonSQLExpression {
override def dataType: DataType = child.dataType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index c290aa8825..20818bfb1a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -249,6 +249,8 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
override def symbol: String = "&&"
+ override def sqlOperator: String = "AND"
+
override def eval(input: InternalRow): Any = {
val input1 = left.eval(input)
if (input1 == false) {
@@ -289,8 +291,6 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
}
"""
}
-
- override def sql: String = s"(${left.sql} AND ${right.sql})"
}
@@ -300,6 +300,8 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
override def symbol: String = "||"
+ override def sqlOperator: String = "OR"
+
override def eval(input: InternalRow): Any = {
val input1 = left.eval(input)
if (input1 == true) {
@@ -340,8 +342,6 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
}
"""
}
-
- override def sql: String = s"(${left.sql} OR ${right.sql})"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index b965212f27..4be065b30a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -23,7 +23,6 @@ import java.util.{HashMap, Locale, Map => JMap}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.ArrayData
-import org.apache.spark.sql.catalyst.util.sequenceOption
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
@@ -62,8 +61,6 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas
}
"""
}
-
- override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})"
}
@@ -156,8 +153,6 @@ case class ConcatWs(children: Seq[Expression])
"""
}
}
-
- override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})"
}
trait String2StringExpression extends ImplicitCastInputTypes {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index afe122f6a0..9e6bd0ee46 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -556,8 +556,8 @@ abstract class RankLike extends AggregateWindowFunction {
override def inputTypes: Seq[AbstractDataType] = children.map(_ => AnyDataType)
/** Store the values of the window 'order' expressions. */
- protected val orderAttrs = children.map{ expr =>
- AttributeReference(expr.prettyString, expr.dataType)()
+ protected val orderAttrs = children.map { expr =>
+ AttributeReference(expr.sql, expr.dataType)()
}
/** Predicate that detects if the order attributes have changed. */
@@ -636,7 +636,7 @@ case class DenseRank(children: Seq[Expression]) extends RankLike {
/**
* The PercentRank function computes the percentage ranking of a value in a group of values. The
- * result the rank of the minus one divided by the total number of rows in the partitiion minus one:
+ * result the rank of the minus one divided by the total number of rows in the partition minus one:
* (r - 1) / (n - 1). If a partition only contains one row, the function will return 0.
*
* The PercentRank function is similar to the CumeDist function, but it uses rank values instead of
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 7d155ac183..a19ba38ba0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -575,7 +575,7 @@ case class Pivot(
override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ aggregates match {
case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)())
case _ => pivotValues.flatMap{ value =>
- aggregates.map(agg => AttributeReference(value + "_" + agg.prettyString, agg.dataType)())
+ aggregates.map(agg => AttributeReference(value + "_" + agg.sql, agg.dataType)())
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
index 7a0d0de632..43f707f444 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
@@ -19,6 +19,9 @@ package org.apache.spark.sql.catalyst
import java.io._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{NumericType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
package object util {
@@ -130,19 +133,31 @@ package object util {
ret
}
+ // Replaces attributes, string literals, and complex type extractors with their pretty form so that
+ // generated column names don't contain back-ticks or double-quotes.
+ def usePrettyExpression(e: Expression): Expression = e transform {
+ case a: Attribute => new PrettyAttribute(a)
+ case Literal(s: UTF8String, StringType) => PrettyAttribute(s.toString, StringType)
+ case Literal(v, t: NumericType) if v != null => PrettyAttribute(v.toString, t)
+ case e: GetStructField =>
+ val name = e.name.getOrElse(e.childSchema(e.ordinal).name)
+ PrettyAttribute(usePrettyExpression(e.child).sql + "." + name, e.dataType)
+ case e: GetArrayStructFields =>
+ PrettyAttribute(usePrettyExpression(e.child) + "." + e.field.name, e.dataType)
+ }
+
+ def quoteIdentifier(name: String): String = {
+ // Escapes back-ticks within the identifier name with double-back-ticks, and then quotes the
+ // identifier with back-ticks.
+ "`" + name.replace("`", "``") + "`"
+ }
+
/**
- * Converts a `Seq` of `Option[T]` to an `Option` of `Seq[T]`.
+ * Returns a version of the given string that is safe to be put in the code comments of
+ * generated code.
*/
- def sequenceOption[T](seq: Seq[Option[T]]): Option[Seq[T]] = seq match {
- case xs if xs.isEmpty =>
- Option(Seq.empty[T])
-
- case xs =>
- for {
- head <- xs.head
- tail <- sequenceOption(xs.tail)
- } yield head +: tail
- }
+ def toCommentSafeString(str: String): String =
+ str.replace("*/", "\\*\\/").replace("\\u", "\\\\u")
/* FIX ME
implicit class debugLogging(a: Any) {
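
Since `quoteIdentifier` and `toCommentSafeString` are plain string helpers, their behaviour can be checked in isolation. A small usage sketch (definitions copied from the hunk above):

    def quoteIdentifier(name: String): String =
      "`" + name.replace("`", "``") + "`"

    def toCommentSafeString(str: String): String =
      str.replace("*/", "\\*\\/").replace("\\u", "\\\\u")

    assert(quoteIdentifier("foo") == "`foo`")
    assert(quoteIdentifier("we`ird") == "`we``ird`")        // embedded back-tick is doubled
    assert(toCommentSafeString("a */ b") == "a \\*\\/ b")   // cannot terminate a generated block comment early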
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index 5dd661ee6b..71ea5b8841 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -64,6 +64,8 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType {
override def toString: String = s"DecimalType($precision,$scale)"
+ override def sql: String = typeName.toUpperCase
+
/**
* Returns whether this DecimalType is wider than `other`. If yes, it means `other`
* can be casted into `this` safely without losing any precision or range.
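
For reference, a hedged example of what the new override yields, assuming `DecimalType.typeName` renders as `decimal(<precision>,<scale>)`:

    import org.apache.spark.sql.types.DecimalType

    // typeName.toUpperCase, so the expected (hedged) output is:
    DecimalType(10, 2).sql   // "DECIMAL(10,2)"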
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index e797d83cb0..5ff5435d5a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -25,7 +25,7 @@ import org.json4s.JsonDSL._
import org.apache.spark.SparkException
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering}
-import org.apache.spark.sql.catalyst.util.{DataTypeParser, LegacyTypeStringParser}
+import org.apache.spark.sql.catalyst.util.{quoteIdentifier, DataTypeParser, LegacyTypeStringParser}
/**
* :: DeveloperApi ::
@@ -280,7 +280,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
}
override def sql: String = {
- val fieldTypes = fields.map(f => s"`${f.name}`: ${f.dataType.sql}")
+ val fieldTypes = fields.map(f => s"${quoteIdentifier(f.name)}: ${f.dataType.sql}")
s"STRUCT<${fieldTypes.mkString(", ")}>"
}
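
Combined with `quoteIdentifier`, field names that themselves contain back-ticks are now escaped rather than producing ambiguous output. A hedged example:

    import org.apache.spark.sql.types.{StringType, StructField, StructType}

    val schema = StructType(Seq(
      StructField("a", StringType),
      StructField("we`ird", StringType)))

    // Hedged expectation after this change:
    // STRUCT<`a`: STRING, `we``ird`: STRING>
    println(schema.sql)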
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 69f78a097e..de9a56dc9c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -127,22 +127,24 @@ class AnalysisErrorSuite extends AnalysisTest {
errorTest(
"single invalid type, single arg",
testRelation.select(TestFunction(dateLit :: Nil, IntegerType :: Nil).as('a)),
- "cannot resolve" :: "testfunction" :: "argument 1" :: "requires int type" ::
- "'null' is of date type" :: Nil)
+ "cannot resolve" :: "testfunction(CAST(NULL AS DATE))" :: "argument 1" :: "requires int type" ::
+ "'CAST(NULL AS DATE)' is of date type" :: Nil)
errorTest(
"single invalid type, second arg",
testRelation.select(
TestFunction(dateLit :: dateLit :: Nil, DateType :: IntegerType :: Nil).as('a)),
- "cannot resolve" :: "testfunction" :: "argument 2" :: "requires int type" ::
- "'null' is of date type" :: Nil)
+ "cannot resolve" :: "testfunction(CAST(NULL AS DATE), CAST(NULL AS DATE))" ::
+ "argument 2" :: "requires int type" ::
+ "'CAST(NULL AS DATE)' is of date type" :: Nil)
errorTest(
"multiple invalid type",
testRelation.select(
TestFunction(dateLit :: dateLit :: Nil, IntegerType :: IntegerType :: Nil).as('a)),
- "cannot resolve" :: "testfunction" :: "argument 1" :: "argument 2" ::
- "requires int type" :: "'null' is of date type" :: Nil)
+ "cannot resolve" :: "testfunction(CAST(NULL AS DATE), CAST(NULL AS DATE))" ::
+ "argument 1" :: "argument 2" :: "requires int type" ::
+ "'CAST(NULL AS DATE)' is of date type" :: Nil)
errorTest(
"invalid window function",
@@ -207,7 +209,7 @@ class AnalysisErrorSuite extends AnalysisTest {
errorTest(
"sorting by attributes are not from grouping expressions",
testRelation2.groupBy('a, 'c)('a, 'c, count('a).as("a3")).orderBy('b.asc),
- "cannot resolve" :: "'b'" :: "given input columns" :: "[a, c, a3]" :: Nil)
+ "cannot resolve" :: "'`b`'" :: "given input columns" :: "[a, c, a3]" :: Nil)
errorTest(
"non-boolean filters",
@@ -222,7 +224,7 @@ class AnalysisErrorSuite extends AnalysisTest {
errorTest(
"missing group by",
testRelation2.groupBy('a)('b),
- "'b'" :: "group by" :: Nil
+ "'`b`'" :: "group by" :: Nil
)
errorTest(
@@ -270,7 +272,7 @@ class AnalysisErrorSuite extends AnalysisTest {
"SPARK-9955: correct error message for aggregate",
// When parse SQL string, we will wrap aggregate expressions with UnresolvedAlias.
testRelation2.where('bad_column > 1).groupBy('a)(UnresolvedAlias(max('b))),
- "cannot resolve 'bad_column'" :: Nil)
+ "cannot resolve '`bad_column`'" :: Nil)
test("SPARK-6452 regression test") {
// CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s)
@@ -311,7 +313,7 @@ class AnalysisErrorSuite extends AnalysisTest {
case true =>
assertAnalysisSuccess(plan, true)
case false =>
- assertAnalysisError(plan, "expression a cannot be used as a grouping expression" :: Nil)
+ assertAnalysisError(plan, "expression `a` cannot be used as a grouping expression" :: Nil)
}
}
@@ -372,7 +374,7 @@ class AnalysisErrorSuite extends AnalysisTest {
Some(EqualTo(AttributeReference("a", BinaryType)(exprId = ExprId(2)),
AttributeReference("c", BinaryType)(exprId = ExprId(4)))))
- assertAnalysisError(plan, "binary type expression a cannot be used in join conditions" :: Nil)
+ assertAnalysisError(plan, "binary type expression `a` cannot be used in join conditions" :: Nil)
val plan2 =
Join(
@@ -386,6 +388,6 @@ class AnalysisErrorSuite extends AnalysisTest {
Some(EqualTo(AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4)))))
- assertAnalysisError(plan2, "map type expression a cannot be used in join conditions" :: Nil)
+ assertAnalysisError(plan2, "map type expression `a` cannot be used in join conditions" :: Nil)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
index af214b7af0..2756c463cf 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
@@ -71,8 +71,17 @@ trait AnalysisTest extends PlanTest {
val e = intercept[AnalysisException] {
analyzer.checkAnalysis(analyzer.execute(inputPlan))
}
- assert(expectedErrors.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains),
- s"Expected to throw Exception contains: ${expectedErrors.mkString(", ")}, " +
- s"actually we get ${e.getMessage}")
+
+ if (!expectedErrors.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains)) {
+ fail(
+ s"""Exception message should contain the following substrings:
+ |
+ | ${expectedErrors.mkString("\n ")}
+ |
+ |Actual exception message:
+ |
+ | ${e.getMessage}
+ """.stripMargin)
+ }
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 59549e3998..92c8496fde 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -41,7 +41,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertSuccess(expr)
}
assert(e.getMessage.contains(
- s"cannot resolve '${expr.prettyString}' due to data type mismatch:"))
+ s"cannot resolve '${expr.sql}' due to data type mismatch:"))
assert(e.getMessage.contains(errorMessage))
}
@@ -52,7 +52,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
def assertErrorForDifferingTypes(expr: Expression): Unit = {
assertError(expr,
- s"differing types in '${expr.prettyString}'")
+ s"differing types in '${expr.sql}'")
}
test("check types for unary arithmetic") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index 8b02b63c6c..3ad0dae767 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -142,7 +142,7 @@ class EncoderResolutionSuite extends PlanTest {
}.message
assert(msg2 ==
s"""
- |Cannot up cast `b.b` from decimal(38,18) to bigint as it may truncate
+ |Cannot up cast `b`.`b` from decimal(38,18) to bigint as it may truncate
|The type path of the target object is:
|- field (class: "scala.Long", name: "b")
|- field (class: "org.apache.spark.sql.catalyst.encoders.StringLongClass", name: "b")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 2ab091e40a..6c7929c362 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -24,7 +24,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.DataTypeParser
+import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DataTypeParser}
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types._
@@ -104,10 +104,9 @@ class Column(protected[sql] val expr: Expression) extends Logging {
def this(name: String) = this(name match {
case "*" => UnresolvedStar(None)
- case _ if name.endsWith(".*") => {
+ case _ if name.endsWith(".*") =>
val parts = UnresolvedAttribute.parseAttributeName(name.substring(0, name.length - 2))
UnresolvedStar(Some(parts))
- }
case _ => UnresolvedAttribute.quotedString(name)
})
@@ -123,6 +122,8 @@ class Column(protected[sql] val expr: Expression) extends Logging {
// make it a NamedExpression.
case u: UnresolvedAttribute => UnresolvedAlias(u)
+ case u: UnresolvedExtractValue => UnresolvedAlias(u)
+
case expr: NamedExpression => expr
// Leave an unaliased generator with an empty list of names since the analyzer will generate
@@ -131,7 +132,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
case jt: JsonTuple => MultiAlias(jt, Nil)
- case func: UnresolvedFunction => UnresolvedAlias(func, Some(func.prettyString))
+ case func: UnresolvedFunction => UnresolvedAlias(func, Some(usePrettyExpression(func).sql))
// If we have a top level Cast, there is a chance to give it a better alias, if there is a
// NamedExpression under this Cast.
@@ -139,13 +140,13 @@ class Column(protected[sql] val expr: Expression) extends Logging {
case Cast(ne: NamedExpression, to) => UnresolvedAlias(Cast(ne, to))
} match {
case ne: NamedExpression => ne
- case other => Alias(expr, expr.prettyString)()
+ case other => Alias(expr, usePrettyExpression(expr).sql)()
}
- case expr: Expression => Alias(expr, expr.prettyString)()
+ case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)()
}
- override def toString: String = expr.prettyString
+ override def toString: String = usePrettyExpression(expr).sql
override def equals(that: Any): Boolean = that match {
case that: Column => that.expr.equals(this.expr)
@@ -987,7 +988,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
if (extended) {
println(expr)
} else {
- println(expr.prettyString)
+ println(expr.sql)
}
// scalastyle:on println
}
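
The practical effect of the Column changes is on auto-generated column names for unnamed expressions. A hedged illustration (assumes an active `SQLContext` and `import sqlContext.implicits._`; exact parenthesisation may differ):

    val df = sqlContext.range(1).selectExpr("true AS a", "false AS b")

    // Before: the unnamed expression was aliased with prettyString, e.g. "a && b".
    // After:  it is aliased with usePrettyExpression(expr).sql, e.g. using "AND",
    //         and free of back-ticks and double quotes.
    df.select($"a" && $"b").columns.foreach(println)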
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 76c09a285d..9674450118 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.util.usePrettyExpression
import org.apache.spark.sql.execution.{ExplainCommand, FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution}
import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation}
import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
@@ -1359,7 +1360,8 @@ class DataFrame private[sql](
"min" -> ((child: Expression) => Min(child).toAggregateExpression()),
"max" -> ((child: Expression) => Max(child).toAggregateExpression()))
- val outputCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList
+ val outputCols =
+ (if (cols.isEmpty) numericColumns.map(usePrettyExpression(_).sql) else cols).toList
val ret: Seq[Row] = if (outputCols.nonEmpty) {
val aggExprs = statistics.flatMap { case (_, colToAgg) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index c74ef2c035..66ec0e7338 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, Unresolved
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Pivot}
+import org.apache.spark.sql.catalyst.util.usePrettyExpression
import org.apache.spark.sql.types.NumericType
/**
@@ -74,7 +75,7 @@ class GroupedData protected[sql](
private[this] def alias(expr: Expression): NamedExpression = expr match {
case u: UnresolvedAttribute => UnresolvedAlias(u)
case expr: NamedExpression => expr
- case expr: Expression => Alias(expr, expr.prettyString)()
+ case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)()
}
private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index ad8564e96a..990eeb22b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.util.toCommentSafeString
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight}
import org.apache.spark.sql.execution.metric.LongSQLMetricValue
@@ -252,7 +253,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
}
/** Codegened pipeline for:
- * ${plan.treeString.trim}
+ * ${toCommentSafeString(plan.treeString.trim)}
*/
class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index 812e696338..ab178443dd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.aggregate
import org.apache.spark.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, MutableRow}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, MutableRow, _}
import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
@@ -324,7 +324,7 @@ private[sql] case class ScalaUDAF(
udaf: UserDefinedAggregateFunction,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
- extends ImperativeAggregate with Logging {
+ extends ImperativeAggregate with NonSQLExpression with Logging {
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala
index 1c0d53fc77..54dda0c391 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala
@@ -544,7 +544,7 @@ private[parquet] object CatalystSchemaConverter {
!name.matches(".*[ ,;{}()\n\t=].*"),
s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\\n\\t=".
|Please use alias to rename it.
- """.stripMargin.split("\n").mkString(" "))
+ """.stripMargin.split("\n").mkString(" ").trim)
}
def checkFieldNames(schema: StructType): StructType = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
index 0e53a0c473..9aff0be716 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.python
import org.apache.spark.{Accumulator, Logging}
import org.apache.spark.api.python.PythonBroadcast
import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.sql.catalyst.expressions.{Expression, Unevaluable}
+import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable}
import org.apache.spark.sql.types.DataType
/**
@@ -36,9 +36,12 @@ case class PythonUDF(
broadcastVars: java.util.List[Broadcast[PythonBroadcast]],
accumulator: Accumulator[java.util.List[Array[Byte]]],
dataType: DataType,
- children: Seq[Expression]) extends Expression with Unevaluable with Logging {
+ children: Seq[Expression])
+ extends Expression with Unevaluable with NonSQLExpression with Logging {
- override def toString: String = s"PythonUDF#$name(${children.mkString(",")})"
+ override def toString: String = s"PythonUDF#$name(${children.mkString(", ")})"
override def nullable: Boolean = true
+
+ override def sql: String = s"$name(${children.mkString(", ")})"
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index f9ba607700..498f007081 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -525,7 +525,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val e = intercept[AnalysisException] {
ds.as[ClassData2]
}
- assert(e.getMessage.contains("cannot resolve 'c' given input columns: [a, b]"), e.getMessage)
+ assert(e.getMessage.contains("cannot resolve '`c`' given input columns: [a, b]"), e.getMessage)
}
test("runtime nullability check") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index bd87449f92..41a9404b00 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -455,7 +455,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
sqlContext.udf.register("div0", (x: Int) => x / 0)
withTempPath { dir =>
intercept[org.apache.spark.SparkException] {
- sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath)
+ sqlContext.sql("select div0(1) as div0").write.parquet(dir.getCanonicalPath)
}
val path = new Path(dir.getCanonicalPath, "_temporary")
val fs = path.getFileSystem(hadoopConfiguration)
@@ -479,7 +479,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
sqlContext.udf.register("div0", (x: Int) => x / 0)
withTempPath { dir =>
intercept[org.apache.spark.SparkException] {
- sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath)
+ sqlContext.sql("select div0(1) as div0").write.parquet(dir.getCanonicalPath)
}
val path = new Path(dir.getCanonicalPath, "_temporary")
val fs = path.getFileSystem(hadoopConfiguration)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
index bf5edb4759..1dda39d44e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -24,10 +24,11 @@ import scala.util.control.NonFatal
import org.apache.spark.Logging
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, NonSQLExpression, SortOrder}
import org.apache.spark.sql.catalyst.optimizer.CollapseProject
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
+import org.apache.spark.sql.catalyst.util.quoteIdentifier
import org.apache.spark.sql.execution.datasources.LogicalRelation
/**
@@ -37,11 +38,21 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
* supported by this builder (yet).
*/
class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
+ require(logicalPlan.resolved, "SQLBuilder only supports resolved logical query plans")
+
def this(df: DataFrame) = this(df.queryExecution.analyzed, df.sqlContext)
def toSQL: String = {
val canonicalizedPlan = Canonicalizer.execute(logicalPlan)
try {
+ canonicalizedPlan.transformAllExpressions {
+ case e: NonSQLExpression =>
+ throw new UnsupportedOperationException(
+ s"Expression $e doesn't have a SQL representation"
+ )
+ case e => e
+ }
+
val generatedSQL = toSQL(canonicalizedPlan)
logDebug(
s"""Built SQL query string successfully from given logical plan:
@@ -95,7 +106,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
p.child match {
// Persisted data source relation
case LogicalRelation(_, _, Some(TableIdentifier(table, Some(database)))) =>
- s"`$database`.`$table`"
+ s"${quoteIdentifier(database)}.${quoteIdentifier(table)}"
// Parentheses are not used for persisted data source relations
// e.g., select x.c1 from (t1) as x inner join (t1) as y on x.c1 = y.c1
case Subquery(_, _: LogicalRelation | _: MetastoreRelation) =>
@@ -114,8 +125,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
case p: MetastoreRelation =>
build(
- s"`${p.databaseName}`.`${p.tableName}`",
- p.alias.map(a => s" AS `$a`").getOrElse("")
+ s"${quoteIdentifier(p.databaseName)}.${quoteIdentifier(p.tableName)}",
+ p.alias.map(a => s" AS ${quoteIdentifier(a)}").getOrElse("")
)
case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _))
@@ -148,7 +159,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
* The segments are trimmed so only a single space appears in the separation.
* For example, `build("a", " b ", " c")` becomes "a b c".
*/
- private def build(segments: String*): String = segments.map(_.trim).mkString(" ")
+ private def build(segments: String*): String =
+ segments.map(_.trim).filter(_.nonEmpty).mkString(" ")
private def projectToSQL(plan: Project, isDistinct: Boolean): String = {
build(
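
The new `NonSQLExpression` guard means plans containing Scala or Python UDFs (which now mix in the trait) fail fast when converted back to SQL, mirroring the new test added in LogicalPlanToSQLSuite further down. A hedged usage sketch:

    // Assumes table t0 and the UDF "foo" exist, as in the test suite below.
    sqlContext.udf.register("foo", (i: Int) => i * 2)
    val df = sqlContext.sql("SELECT foo(id) FROM t0")

    try {
      new SQLBuilder(df).toSQL
    } catch {
      case e: UnsupportedOperationException =>
        // The exception message embeds the offending expression's toString.
        println(e.getMessage)
    }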
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index d5ed838ca4..bcafa045e0 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.catalyst.util.sequenceOption
import org.apache.spark.sql.hive.HiveShim._
import org.apache.spark.sql.hive.client.HiveClientImpl
import org.apache.spark.sql.types._
@@ -70,18 +69,28 @@ private[hive] class HiveFunctionRegistry(
// catch the exception and throw AnalysisException instead.
try {
if (classOf[GenericUDFMacro].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveGenericUDF(
+ val udf = HiveGenericUDF(
name, new HiveFunctionWrapper(functionClassName, functionInfo.getGenericUDF), children)
+ udf.dataType // Force it to check input data types.
+ udf
} else if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveSimpleUDF(name, new HiveFunctionWrapper(functionClassName), children)
+ val udf = HiveSimpleUDF(name, new HiveFunctionWrapper(functionClassName), children)
+ udf.dataType // Force it to check input data types.
+ udf
} else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveGenericUDF(name, new HiveFunctionWrapper(functionClassName), children)
+ val udf = HiveGenericUDF(name, new HiveFunctionWrapper(functionClassName), children)
+ udf.dataType // Force it to check input data types.
+ udf
} else if (
classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveUDAFFunction(name, new HiveFunctionWrapper(functionClassName), children)
+ val udaf = HiveUDAFFunction(name, new HiveFunctionWrapper(functionClassName), children)
+ udaf.dataType // Force it to check input data types.
+ udaf
} else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveUDAFFunction(
+ val udaf = HiveUDAFFunction(
name, new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true)
+ udaf.dataType // Force it to check input data types.
+ udaf
} else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(functionClassName), children)
udtf.elementTypes // Force it to check input data types.
@@ -163,7 +172,7 @@ private[hive] case class HiveSimpleUDF(
@transient
private lazy val conversionHelper = new ConversionHelper(method, arguments)
- override val dataType = javaClassToDataType(method.getReturnType)
+ override lazy val dataType = javaClassToDataType(method.getReturnType)
@transient
lazy val returnInspector = ObjectInspectorFactory.getReflectionObjectInspector(
@@ -189,6 +198,8 @@ private[hive] case class HiveSimpleUDF(
s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
}
+ override def prettyName: String = name
+
override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})"
}
@@ -233,11 +244,11 @@ private[hive] case class HiveGenericUDF(
}
@transient
- private lazy val deferedObjects = argumentInspectors.zip(children).map { case (inspect, child) =>
+ private lazy val deferredObjects = argumentInspectors.zip(children).map { case (inspect, child) =>
new DeferredObjectAdapter(inspect, child.dataType)
}.toArray[DeferredObject]
- override val dataType: DataType = inspectorToDataType(returnInspector)
+ override lazy val dataType: DataType = inspectorToDataType(returnInspector)
override def eval(input: InternalRow): Any = {
returnInspector // Make sure initialized.
@@ -245,20 +256,20 @@ private[hive] case class HiveGenericUDF(
var i = 0
while (i < children.length) {
val idx = i
- deferedObjects(i).asInstanceOf[DeferredObjectAdapter].set(
+ deferredObjects(i).asInstanceOf[DeferredObjectAdapter].set(
() => {
children(idx).eval(input)
})
i += 1
}
- unwrap(function.evaluate(deferedObjects), returnInspector)
+ unwrap(function.evaluate(deferredObjects), returnInspector)
}
+ override def prettyName: String = name
+
override def toString: String = {
s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
}
-
- override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})"
}
/**
@@ -340,7 +351,7 @@ private[hive] case class HiveGenericUDTF(
s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
}
- override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})"
+ override def prettyName: String = name
}
/**
@@ -432,7 +443,9 @@ private[hive] case class HiveUDAFFunction(
override def supportsPartial: Boolean = false
- override val dataType: DataType = inspectorToDataType(returnInspector)
+ override lazy val dataType: DataType = inspectorToDataType(returnInspector)
+
+ override def prettyName: String = name
override def sql(isDistinct: Boolean): String = {
val distinct = if (isDistinct) "DISTINCT " else " "
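
The repeated `udf.dataType // Force it to check input data types.` calls above rely on `dataType` now being a `lazy val` that resolves the return type from the arguments, so forcing it at registry-lookup time surfaces bad argument types early (re-thrown as an AnalysisException by the surrounding try) instead of failing during execution. A generic, simplified sketch of the idiom (not Hive-specific):

    // Evaluating a lazy val that depends on the inputs performs the validation once,
    // at construction/lookup time rather than on first evaluation.
    final class TypeCheckedFn(args: Seq[Any]) {
      lazy val dataType: String = args match {
        case Seq(_: Int) => "int"
        case other       => throw new IllegalArgumentException(s"unsupported arguments: $other")
      }
    }

    val fn = new TypeCheckedFn(Seq("not an int"))
    fn.dataType   // throws here, not later when the function is first evaluated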
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala
index 3fb6543b1a..e4b4d1861a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala
@@ -26,17 +26,24 @@ class ExpressionSQLBuilderSuite extends SQLBuilderTest {
test("literal") {
checkSQL(Literal("foo"), "\"foo\"")
checkSQL(Literal("\"foo\""), "\"\\\"foo\\\"\"")
- checkSQL(Literal(1: Byte), "CAST(1 AS TINYINT)")
- checkSQL(Literal(2: Short), "CAST(2 AS SMALLINT)")
+ checkSQL(Literal(1: Byte), "1Y")
+ checkSQL(Literal(2: Short), "2S")
checkSQL(Literal(4: Int), "4")
- checkSQL(Literal(8: Long), "CAST(8 AS BIGINT)")
+ checkSQL(Literal(8: Long), "8L")
checkSQL(Literal(1.5F), "CAST(1.5 AS FLOAT)")
- checkSQL(Literal(2.5D), "2.5")
+ checkSQL(Literal(2.5D), "2.5D")
checkSQL(
Literal(Timestamp.valueOf("2016-01-01 00:00:00")), "TIMESTAMP('2016-01-01 00:00:00.0')")
// TODO tests for decimals
}
+ test("attributes") {
+ checkSQL('a.int, "`a`")
+ checkSQL(Symbol("foo bar").int, "`foo bar`")
+ // Keyword
+ checkSQL('int.int, "`int`")
+ }
+
test("binary comparisons") {
checkSQL('a.int === 'b.int, "(`a` = `b`)")
checkSQL('a.int <=> 'b.int, "(`a` <=> `b`)")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
index dc8ac7e47f..5255b150aa 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
@@ -29,6 +29,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
sql("DROP TABLE IF EXISTS t0")
sql("DROP TABLE IF EXISTS t1")
sql("DROP TABLE IF EXISTS t2")
+
sqlContext.range(10).write.saveAsTable("t0")
sqlContext
@@ -168,4 +169,67 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
}
}
}
+
+ test("plans with non-SQL expressions") {
+ sqlContext.udf.register("foo", (_: Int) * 2)
+ intercept[UnsupportedOperationException](new SQLBuilder(sql("SELECT foo(id) FROM t0")).toSQL)
+ }
+
+ test("named expression in column names shouldn't be quoted") {
+ def checkColumnNames(query: String, expectedColNames: String*): Unit = {
+ checkHiveQl(query)
+ assert(sql(query).columns === expectedColNames)
+ }
+
+ // Attributes
+ checkColumnNames(
+ """SELECT * FROM (
+ | SELECT 1 AS a, 2 AS b, 3 AS `we``ird`
+ |) s
+ """.stripMargin,
+ "a", "b", "we`ird"
+ )
+
+ checkColumnNames(
+ """SELECT x.a, y.a, x.b, y.b
+ |FROM (SELECT 1 AS a, 2 AS b) x
+ |INNER JOIN (SELECT 1 AS a, 2 AS b) y
+ |ON x.a = y.a
+ """.stripMargin,
+ "a", "a", "b", "b"
+ )
+
+ // String literal
+ checkColumnNames(
+ "SELECT 'foo', '\"bar\\''",
+ "foo", "\"bar\'"
+ )
+
+ // Numeric literals (should not have CAST or type suffixes in their column names)
+ checkColumnNames(
+ "SELECT 1Y, 2S, 3, 4L, 5.1, 6.1D",
+ "1", "2", "3", "4", "5.1", "6.1"
+ )
+
+ // Aliases
+ checkColumnNames(
+ "SELECT 1 AS a",
+ "a"
+ )
+
+ // Complex type extractors
+ checkColumnNames(
+ """SELECT
+ | a.f1, b[0].f1, b.f1, c["foo"], d[0]
+ |FROM (
+ | SELECT
+ | NAMED_STRUCT("f1", 1, "f2", "foo") AS a,
+ | ARRAY(NAMED_STRUCT("f1", 1, "f2", "foo")) AS b,
+ | MAP("foo", 1) AS c,
+ | ARRAY(1) AS d
+ |) s
+ """.stripMargin,
+ "f1", "b[0].f1", "f1", "c[foo]", "d[0]"
+ )
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala
index cd055f9eca..d8d3448add 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala
@@ -29,8 +29,8 @@ class HivePlanTest extends QueryTest with TestHiveSingleton {
test("udf constant folding") {
Seq.empty[Tuple1[Int]].toDF("a").registerTempTable("t")
- val optimized = sql("SELECT cos(null) FROM t").queryExecution.optimizedPlan
- val correctAnswer = sql("SELECT cast(null as double) FROM t").queryExecution.optimizedPlan
+ val optimized = sql("SELECT cos(null) AS c FROM t").queryExecution.optimizedPlan
+ val correctAnswer = sql("SELECT cast(null as double) AS c FROM t").queryExecution.optimizedPlan
comparePlans(optimized, correctAnswer)
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
index 27ea3e8041..fe446774ef 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
@@ -131,17 +131,17 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA
val df = sql(
"""
|SELECT
- | CAST(null as TINYINT),
- | CAST(null as SMALLINT),
- | CAST(null as INT),
- | CAST(null as BIGINT),
- | CAST(null as FLOAT),
- | CAST(null as DOUBLE),
- | CAST(null as DECIMAL(7,2)),
- | CAST(null as TIMESTAMP),
- | CAST(null as DATE),
- | CAST(null as STRING),
- | CAST(null as VARCHAR(10))
+ | CAST(null as TINYINT) as c0,
+ | CAST(null as SMALLINT) as c1,
+ | CAST(null as INT) as c2,
+ | CAST(null as BIGINT) as c3,
+ | CAST(null as FLOAT) as c4,
+ | CAST(null as DOUBLE) as c5,
+ | CAST(null as DECIMAL(7,2)) as c6,
+ | CAST(null as TIMESTAMP) as c7,
+ | CAST(null as DATE) as c8,
+ | CAST(null as STRING) as c9,
+ | CAST(null as VARCHAR(10)) as c10
|FROM orc_temp_table limit 1
""".stripMargin)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
index c997453803..a6ca7d0386 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
@@ -626,7 +626,10 @@ class ParquetSourceSuite extends ParquetPartitioningTest {
sql(
s"""CREATE TABLE array_of_struct
|STORED AS PARQUET LOCATION '$path'
- |AS SELECT '1st', '2nd', ARRAY(NAMED_STRUCT('a', 'val_a', 'b', 'val_b'))
+ |AS SELECT
+ | '1st' AS a,
+ | '2nd' AS b,
+ | ARRAY(NAMED_STRUCT('a', 'val_a', 'b', 'val_b')) AS c
""".stripMargin)
checkAnswer(