aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala20
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala3
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala47
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala7
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala5
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala2
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala11
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala16
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala3
9 files changed, 78 insertions, 36 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 4d53b232d5..62b241f052 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -416,9 +416,10 @@ class Analyzer(
case UnresolvedAlias(f @ UnresolvedFunction(_, args, _), _) if containsStar(args) =>
val newChildren = expandStarExpressions(args, child)
UnresolvedAlias(child = f.copy(children = newChildren)) :: Nil
- case Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args) =>
+ case a @ Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args) =>
val newChildren = expandStarExpressions(args, child)
- Alias(child = f.copy(children = newChildren), name)() :: Nil
+ Alias(child = f.copy(children = newChildren), name)(
+ isGenerated = a.isGenerated) :: Nil
case UnresolvedAlias(c @ CreateArray(args), _) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child, resolver)
@@ -528,7 +529,7 @@ class Analyzer(
def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
expressions.map {
- case a: Alias => Alias(a.child, a.name)()
+ case a: Alias => Alias(a.child, a.name)(isGenerated = a.isGenerated)
case other => other
}
}
@@ -734,7 +735,10 @@ class Analyzer(
// Try resolving the condition of the filter as though it is in the aggregate clause
val aggregatedCondition =
- Aggregate(grouping, Alias(havingCondition, "havingCondition")() :: Nil, child)
+ Aggregate(
+ grouping,
+ Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil,
+ child)
val resolvedOperator = execute(aggregatedCondition)
def resolvedAggregateFilter =
resolvedOperator
@@ -759,7 +763,8 @@ class Analyzer(
// Try resolving the ordering as though it is in the aggregate clause.
try {
val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s))
- val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")())
+ val aliasedOrdering =
+ unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")(isGenerated = true))
val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate]
val resolvedAliasedOrdering: Seq[Alias] =
@@ -1190,7 +1195,7 @@ class Analyzer(
leafNondeterministic.map { e =>
val ne = e match {
case n: NamedExpression => n
- case _ => Alias(e, "_nondeterministic")()
+ case _ => Alias(e, "_nondeterministic")(isGenerated = true)
}
new TreeNodeRef(e) -> ne
}
@@ -1355,7 +1360,8 @@ object CleanupAliases extends Rule[LogicalPlan] {
def trimNonTopLevelAliases(e: Expression): Expression = e match {
case a: Alias =>
- Alias(trimAliases(a.child), a.name)(a.exprId, a.qualifiers, a.explicitMetadata)
+ Alias(trimAliases(a.child), a.name)(
+ a.exprId, a.qualifiers, a.explicitMetadata, a.isGenerated)
case other => trimAliases(other)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
index 4e7d134102..5dfce89bd6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
@@ -126,7 +126,8 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
// Aggregation strategy can handle the query with single distinct
if (distinctAggGroups.size > 1) {
// Create the attributes for the grouping id and the group by clause.
- val gid = new AttributeReference("gid", IntegerType, false)()
+ val gid =
+ new AttributeReference("gid", IntegerType, false)(isGenerated = true)
val groupByMap = a.groupingExpressions.collect {
case ne: NamedExpression => ne -> ne.toAttribute
case e => e -> new AttributeReference(e.prettyString, e.dataType, e.nullable)()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 7983501ada..207b8a0a88 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -79,6 +79,9 @@ trait NamedExpression extends Expression {
/** Returns the metadata when an expression is a reference to another expression with metadata. */
def metadata: Metadata = Metadata.empty
+ /** Returns true if the expression is generated by Catalyst */
+ def isGenerated: java.lang.Boolean = false
+
/** Returns a copy of this expression with a new `exprId`. */
def newInstance(): NamedExpression
@@ -114,16 +117,21 @@ abstract class Attribute extends LeafExpression with NamedExpression {
* Note that exprId and qualifiers are in a separate parameter list because
* we only pattern match on child and name.
*
- * @param child the computation being performed
- * @param name the name to be associated with the result of computing [[child]].
+ * @param child The computation being performed
+ * @param name The name to be associated with the result of computing [[child]].
* @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this
* alias. Auto-assigned if left blank.
+ * @param qualifiers A list of strings that can be used to refer to this attribute in a fully
+ * qualified way. Consider the examples tableName.name, subQueryAlias.name.
+ * tableName and subQueryAlias are possible qualifiers.
* @param explicitMetadata Explicit metadata associated with this alias that overwrites child's.
+ * @param isGenerated A flag to indicate if this alias is generated by Catalyst
*/
case class Alias(child: Expression, name: String)(
val exprId: ExprId = NamedExpression.newExprId,
val qualifiers: Seq[String] = Nil,
- val explicitMetadata: Option[Metadata] = None)
+ val explicitMetadata: Option[Metadata] = None,
+ override val isGenerated: java.lang.Boolean = false)
extends UnaryExpression with NamedExpression {
// Alias(Generator, xx) need to be transformed into Generate(generator, ...)
@@ -148,11 +156,13 @@ case class Alias(child: Expression, name: String)(
}
def newInstance(): NamedExpression =
- Alias(child, name)(qualifiers = qualifiers, explicitMetadata = explicitMetadata)
+ Alias(child, name)(
+ qualifiers = qualifiers, explicitMetadata = explicitMetadata, isGenerated = isGenerated)
override def toAttribute: Attribute = {
if (resolved) {
- AttributeReference(name, child.dataType, child.nullable, metadata)(exprId, qualifiers)
+ AttributeReference(name, child.dataType, child.nullable, metadata)(
+ exprId, qualifiers, isGenerated)
} else {
UnresolvedAttribute(name)
}
@@ -161,7 +171,7 @@ case class Alias(child: Expression, name: String)(
override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix"
override protected final def otherCopyArgs: Seq[AnyRef] = {
- exprId :: qualifiers :: explicitMetadata :: Nil
+ exprId :: qualifiers :: explicitMetadata :: isGenerated :: Nil
}
override def equals(other: Any): Boolean = other match {
@@ -174,7 +184,8 @@ case class Alias(child: Expression, name: String)(
override def sql: String = {
val qualifiersString =
if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".")
- s"${child.sql} AS $qualifiersString`$name`"
+ val aliasName = if (isGenerated) s"$name#${exprId.id}" else s"$name"
+ s"${child.sql} AS $qualifiersString`$aliasName`"
}
}
@@ -187,9 +198,10 @@ case class Alias(child: Expression, name: String)(
* @param metadata The metadata of this attribute.
* @param exprId A globally unique id used to check if different AttributeReferences refer to the
* same attribute.
- * @param qualifiers a list of strings that can be used to referred to this attribute in a fully
+ * @param qualifiers A list of strings that can be used to refer to this attribute in a fully
* qualified way. Consider the examples tableName.name, subQueryAlias.name.
* tableName and subQueryAlias are possible qualifiers.
+ * @param isGenerated A flag to indicate if this reference is generated by Catalyst
*/
case class AttributeReference(
name: String,
@@ -197,7 +209,8 @@ case class AttributeReference(
nullable: Boolean = true,
override val metadata: Metadata = Metadata.empty)(
val exprId: ExprId = NamedExpression.newExprId,
- val qualifiers: Seq[String] = Nil)
+ val qualifiers: Seq[String] = Nil,
+ override val isGenerated: java.lang.Boolean = false)
extends Attribute with Unevaluable {
/**
@@ -234,7 +247,8 @@ case class AttributeReference(
}
override def newInstance(): AttributeReference =
- AttributeReference(name, dataType, nullable, metadata)(qualifiers = qualifiers)
+ AttributeReference(name, dataType, nullable, metadata)(
+ qualifiers = qualifiers, isGenerated = isGenerated)
/**
* Returns a copy of this [[AttributeReference]] with changed nullability.
@@ -243,7 +257,7 @@ case class AttributeReference(
if (nullable == newNullability) {
this
} else {
- AttributeReference(name, dataType, newNullability, metadata)(exprId, qualifiers)
+ AttributeReference(name, dataType, newNullability, metadata)(exprId, qualifiers, isGenerated)
}
}
@@ -251,7 +265,7 @@ case class AttributeReference(
if (name == newName) {
this
} else {
- AttributeReference(newName, dataType, nullable)(exprId, qualifiers)
+ AttributeReference(newName, dataType, nullable, metadata)(exprId, qualifiers, isGenerated)
}
}
@@ -262,7 +276,7 @@ case class AttributeReference(
if (newQualifiers.toSet == qualifiers.toSet) {
this
} else {
- AttributeReference(name, dataType, nullable, metadata)(exprId, newQualifiers)
+ AttributeReference(name, dataType, nullable, metadata)(exprId, newQualifiers, isGenerated)
}
}
@@ -270,12 +284,12 @@ case class AttributeReference(
if (exprId == newExprId) {
this
} else {
- AttributeReference(name, dataType, nullable, metadata)(newExprId, qualifiers)
+ AttributeReference(name, dataType, nullable, metadata)(newExprId, qualifiers, isGenerated)
}
}
override protected final def otherCopyArgs: Seq[AnyRef] = {
- exprId :: qualifiers :: Nil
+ exprId :: qualifiers :: isGenerated :: Nil
}
override def toString: String = s"$name#${exprId.id}$typeSuffix"
@@ -287,7 +301,8 @@ case class AttributeReference(
override def sql: String = {
val qualifiersString =
if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".")
- s"$qualifiersString`$name`"
+ val attrRefName = if (isGenerated) s"$name#${exprId.id}" else s"$name"
+ s"$qualifiersString`$attrRefName`"
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index f0ee124e88..7302b63646 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -78,10 +78,13 @@ object PhysicalOperation extends PredicateHelper {
private def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = {
expr.transform {
case a @ Alias(ref: AttributeReference, name) =>
- aliases.get(ref).map(Alias(_, name)(a.exprId, a.qualifiers)).getOrElse(a)
+ aliases.get(ref)
+ .map(Alias(_, name)(a.exprId, a.qualifiers, isGenerated = a.isGenerated))
+ .getOrElse(a)
case a: AttributeReference =>
- aliases.get(a).map(Alias(_, a.name)(a.exprId, a.qualifiers)).getOrElse(a)
+ aliases.get(a)
+ .map(Alias(_, a.name)(a.exprId, a.qualifiers, isGenerated = a.isGenerated)).getOrElse(a)
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index d8944a4241..18b7bde906 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -139,7 +139,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
case a: Alias =>
// As the root of the expression, Alias will always take an arbitrary exprId, we need
// to erase that for equality testing.
- val cleanedExprId = Alias(a.child, a.name)(ExprId(-1), a.qualifiers)
+ val cleanedExprId =
+ Alias(a.child, a.name)(ExprId(-1), a.qualifiers, isGenerated = a.isGenerated)
BindReferences.bindReference(cleanedExprId, input, allowFailures = true)
case other => BindReferences.bindReference(other, input, allowFailures = true)
}
@@ -222,7 +223,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
nameParts: Seq[String],
resolver: Resolver,
attribute: Attribute): Option[(Attribute, List[String])] = {
- if (resolver(attribute.name, nameParts.head)) {
+ if (!attribute.isGenerated && resolver(attribute.name, nameParts.head)) {
Option((attribute.withName(nameParts.head), nameParts.tail.toList))
} else {
None
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 2df0683f9f..30df2a84f6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -656,6 +656,8 @@ object TreeNode {
case t if t <:< definitions.DoubleTpe =>
value.asInstanceOf[JDouble].num: java.lang.Double
+ case t if t <:< localTypeOf[java.lang.Boolean] =>
+ value.asInstanceOf[JBool].value: java.lang.Boolean
case t if t <:< localTypeOf[BigInt] => value.asInstanceOf[JInt].num
case t if t <:< localTypeOf[java.lang.String] => value.asInstanceOf[JString].s
case t if t <:< localTypeOf[UUID] => UUID.fromString(value.asInstanceOf[JString].s)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index fc35959f20..e0cec09742 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -23,10 +23,10 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count, Sum}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.Inner
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
@BeanInfo
@@ -177,6 +177,13 @@ class AnalysisErrorSuite extends AnalysisTest {
"cannot resolve" :: "abcd" :: Nil)
errorTest(
+ "unresolved attributes with a generated name",
+ testRelation2.groupBy('a)(max('b))
+ .where(sum('b) > 0)
+ .orderBy('havingCondition.asc),
+ "cannot resolve" :: "havingCondition" :: Nil)
+
+ errorTest(
"bad casts",
testRelation.select(Literal(1).cast(BinaryType).as('badCast)),
"cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index c02133ffc8..3ea4adcaa6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -998,12 +998,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
}
- test("SPARK-10034: Sort on Aggregate with aggregation expression named 'aggOrdering'") {
+ test("Alias uses internally generated names 'aggOrder' and 'havingCondition'") {
val df = Seq(1 -> 2).toDF("i", "j")
- val query = df.groupBy('i)
- .agg(max('j).as("aggOrdering"))
+ val query1 = df.groupBy('i)
+ .agg(max('j).as("aggOrder"))
.orderBy(sum('j))
- checkAnswer(query, Row(1, 2))
+ checkAnswer(query1, Row(1, 2))
+
+ // In the plan, there are two attributes having the same name 'havingCondition'
+ // One is a user-provided alias name; the other is an internally generated one.
+ val query2 = df.groupBy('i)
+ .agg(max('j).as("havingCondition"))
+ .where(sum('j) > 0)
+ .orderBy('havingCondition.asc)
+ checkAnswer(query2, Row(1, 2))
}
test("SPARK-10316: respect non-deterministic expressions in PhysicalOperation") {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
index 1f731db26f..129bfe0a7d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
@@ -92,12 +92,11 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key ORDER BY MAX(key)")
}
- // TODO Fix name collision introduced by ResolveAggregateFunction analysis rule
// When there are multiple aggregate functions in ORDER BY clause, all of them are extracted into
// Aggregate operator and aliased to the same name "aggOrder". This is OK for normal query
// execution since these aliases have different expression ID. But this introduces name collision
// when converting resolved plans back to SQL query strings as expression IDs are stripped.
- ignore("aggregate function in order by clause with multiple order keys") {
+ test("aggregate function in order by clause with multiple order keys") {
checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key ORDER BY key, MAX(key)")
}