author     Michael Armbrust <michael@databricks.com>   2014-08-26 16:29:14 -0700
committer  Reynold Xin <rxin@apache.org>               2014-08-26 16:29:14 -0700
commit     c4787a3690a9ed3b8b2c6c294fc4a6915436b6f7 (patch)
tree       15b185728ed6e46fd93f795780a6266fc42ffd76
parent     1208f72ac78960fe5060187761479b2a9a417c1b (diff)
[SPARK-3194][SQL] Add AttributeSet to fix bugs with invalid comparisons of AttributeReferences
It is common to want to describe sets of attributes that appear in various parts of a query plan. However, the semantics of putting `AttributeReference` objects into a standard Scala `Set` result in subtle bugs when references differ cosmetically. For example, with case-insensitive resolution it is possible to have two references to the same attribute whose names are not equal.

In this PR I introduce a new abstraction, an `AttributeSet`, which performs all comparisons using the globally unique `ExpressionId` instead of case-class equality. (There is already a related class, [`AttributeMap`](https://github.com/marmbrus/spark/blob/inMemStats/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala#L32).) This new type of set is used to fix a bug in the optimizer where needed attributes were getting projected away underneath join operators.

I also took this opportunity to refactor the expression and query plan base classes. In all but one instance the logic for computing the `references` of an `Expression` was the same, so I moved it into the base class. For query plans the semantics of the `references` method were ill-defined (did it mean the plan's output? the attributes used during expression evaluation? something else?), and as a result the method wasn't really used very much. So, I removed it.

TODO:
- [x] Finish scaladoc for `AttributeSet`
- [x] Scan the code for other instances of `Set[Attribute]` and refactor them
- [x] Finish removing `references` from `QueryPlan`

Author: Michael Armbrust <michael@databricks.com>

Closes #2109 from marmbrus/attributeSets and squashes the following commits:

1c0dae5 [Michael Armbrust] work on serialization bug.
9ba868d [Michael Armbrust] Merge remote-tracking branch 'origin/master' into attributeSets
3ae5288 [Michael Armbrust] review comments
40ce7f6 [Michael Armbrust] style
d577cc7 [Michael Armbrust] Scaladoc
cae5d22 [Michael Armbrust] remove more references implementations
d6e16be [Michael Armbrust] Remove more instances of "def references" and normal sets of attributes.
fc26b49 [Michael Armbrust] Add AttributeSet class, remove references from Expression.
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala | 1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala | 106
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala | 6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala | 1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala | 1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala | 1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala | 25
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala | 6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala | 6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala | 5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala | 11
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala | 29
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala | 3
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala | 1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala | 2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala | 8
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala | 5
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala | 9
35 files changed, 166 insertions, 123 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index c18d7858f0..4a95240741 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -132,7 +132,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
case s @ Sort(ordering, p @ Project(projectList, child)) if !s.resolved && p.resolved =>
val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
val resolved = unresolved.flatMap(child.resolveChildren)
- val requiredAttributes = resolved.collect { case a: Attribute => a }.toSet
+ val requiredAttributes = AttributeSet(resolved.collect { case a: Attribute => a })
val missingInProject = requiredAttributes -- p.output
if (missingInProject.nonEmpty) {
@@ -152,8 +152,8 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
)
logDebug(s"Grouping expressions: $groupingRelation")
- val resolved = unresolved.flatMap(groupingRelation.resolve).toSet
- val missingInAggs = resolved -- a.outputSet
+ val resolved = unresolved.flatMap(groupingRelation.resolve)
+ val missingInAggs = resolved.filterNot(a.outputSet.contains)
logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs")
if (missingInAggs.nonEmpty) {
// Add missing grouping exprs and then project them away after the sort.
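A note on the second hunk: `resolved` is now kept as a plain, order-preserving `Seq` rather than being collapsed into a `Set`, so the old set difference becomes a `filterNot` against the exprId-based membership test of `a.outputSet`. A sketch with stand-in string values in place of attributes:

```scala
val resolved  = Seq("b#2", "c#3")                           // resolved sort attributes
val outputSet = Set("a#1", "b#2")                           // stand-in for a.outputSet
val missingInAggs = resolved.filterNot(outputSet.contains)  // Seq("c#3")
```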
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index a0e25775da..a2c61c6548 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -66,7 +66,6 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends E
override def dataType = throw new UnresolvedException(this, "dataType")
override def foldable = throw new UnresolvedException(this, "foldable")
override def nullable = throw new UnresolvedException(this, "nullable")
- override def references = children.flatMap(_.references).toSet
override lazy val resolved = false
// Unresolved functions are transient at compile time and don't get evaluated during execution.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
new file mode 100644
index 0000000000..c3a08bbdb6
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+protected class AttributeEquals(val a: Attribute) {
+ override def hashCode() = a.exprId.hashCode()
+ override def equals(other: Any) = other match {
+ case otherReference: AttributeEquals => a.exprId == otherReference.a.exprId
+ case otherAttribute => false
+ }
+}
+
+object AttributeSet {
+ /** Constructs a new [[AttributeSet]] given a sequence of [[Attribute Attributes]]. */
+ def apply(baseSet: Seq[Attribute]) = {
+ new AttributeSet(baseSet.map(new AttributeEquals(_)).toSet)
+ }
+}
+
+/**
+ * A Set designed to hold [[AttributeReference]] objects that performs equality checking using
+ * expression ids instead of standard Java equality. Using expression ids means that these
+ * sets will correctly test for membership, even when the AttributeReferences in question differ
+ * cosmetically (e.g., the names have different capitalizations).
+ *
+ * Note that we do not override equality for Attribute references as it is really weird when
+ * `AttributeReference("a", ...) == AttributeReference("b", ...)`. This tactic leads to broken
+ * tests and also makes doing transformations hard (we always try to keep older trees instead of
+ * new ones when the transformation was a no-op).
+ */
+class AttributeSet private (val baseSet: Set[AttributeEquals])
+ extends Traversable[Attribute] with Serializable {
+
+ /** Returns true if the members of this AttributeSet and other are the same. */
+ override def equals(other: Any) = other match {
+ case otherSet: AttributeSet =>
+ otherSet.size == baseSet.size && baseSet.map(_.a).forall(otherSet.contains)
+ case _ => false
+ }
+
+ /** Returns true if this set contains an Attribute with the same expression id as `elem` */
+ def contains(elem: NamedExpression): Boolean =
+ baseSet.contains(new AttributeEquals(elem.toAttribute))
+
+ /** Returns a new [[AttributeSet]] that contains `elem` in addition to the current elements. */
+ def +(elem: Attribute): AttributeSet = // scalastyle:ignore
+ new AttributeSet(baseSet + new AttributeEquals(elem))
+
+ /** Returns a new [[AttributeSet]] that does not contain `elem`. */
+ def -(elem: Attribute): AttributeSet =
+ new AttributeSet(baseSet - new AttributeEquals(elem))
+
+ /** Returns an iterator containing all of the attributes in the set. */
+ def iterator: Iterator[Attribute] = baseSet.map(_.a).iterator
+
+ /**
+ * Returns true if the [[Attribute Attributes]] in this set are a subset of the Attributes in
+ * `other`.
+ */
+ def subsetOf(other: AttributeSet) = baseSet.subsetOf(other.baseSet)
+
+ /**
+ * Returns a new [[AttributeSet]] that does not contain any of the [[Attribute Attributes]] found
+ * in `other`.
+ */
+ def --(other: Traversable[NamedExpression]) =
+ new AttributeSet(baseSet -- other.map(a => new AttributeEquals(a.toAttribute)))
+
+ /**
+ * Returns a new [[AttributeSet]] that contains all of the [[Attribute Attributes]] found
+ * in `other`.
+ */
+ def ++(other: AttributeSet) = new AttributeSet(baseSet ++ other.baseSet)
+
+ /**
+ * Returns a new [[AttributeSet]] containing only the [[Attribute Attributes]] where `f`
+ * evaluates to true.
+ */
+ override def filter(f: Attribute => Boolean) = new AttributeSet(baseSet.filter(ae => f(ae.a)))
+
+ /**
+ * Returns a new [[AttributeSet]] that only contains [[Attribute Attributes]] that are found in
+ * `this` and `other`.
+ */
+ def intersect(other: AttributeSet) = new AttributeSet(baseSet.intersect(other.baseSet))
+
+ override def foreach[U](f: (Attribute) => U): Unit = baseSet.map(_.a).foreach(f)
+
+ // We must force toSeq to be strict, otherwise we end up with a [[Stream]] that captures all
+ // sorts of things in its closure.
+ override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq
+}
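The `AttributeEquals` wrapper above is a general trick worth calling out: rather than overriding equality on `Attribute` itself (which the scaladoc explains would break tests and transformations), each element is wrapped in an object whose `hashCode`/`equals` delegate to the `exprId` alone, so an ordinary Scala `Set` of wrappers compares by expression id. A standalone sketch with toy types (not the Catalyst classes):

```scala
case class Attr(name: String, exprId: Long)

class AttrEquals(val a: Attr) {
  // Equality and hashing delegate to exprId, ignoring cosmetic differences.
  override def hashCode(): Int = a.exprId.hashCode()
  override def equals(other: Any): Boolean = other match {
    case that: AttrEquals => a.exprId == that.a.exprId
    case _                => false
  }
}

val base = Set(new AttrEquals(Attr("col", 1)))
base.contains(new AttrEquals(Attr("COL", 1)))  // true: same exprId wins
```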
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 0913f15888..54c6baf1af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -32,8 +32,6 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
type EvaluatedType = Any
- override def references = Set.empty
-
override def toString = s"input[$ordinal]"
override def eval(input: Row): Any = input(ordinal)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index ba62dabe3d..70507e7ee2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -41,7 +41,7 @@ abstract class Expression extends TreeNode[Expression] {
*/
def foldable: Boolean = false
def nullable: Boolean
- def references: Set[Attribute]
+ def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator))
/** Returns the result of evaluating this expression on a given input Row */
def eval(input: Row = null): EvaluatedType
@@ -230,8 +230,6 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
override def foldable = left.foldable && right.foldable
- override def references = left.references ++ right.references
-
override def toString = s"($left $symbol $right)"
}
@@ -242,5 +240,5 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]
abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
self: Product =>
- override def references = child.references
+
}
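With `references` defined once on `Expression`, every node derives its reference set from its children, and only leaves like `AttributeReference` override it; this is why the `BinaryExpression` and `UnaryExpression` overrides above could simply be deleted. A toy tree (stand-in types, not Catalyst) showing the recursion:

```scala
sealed trait Expr {
  def children: Seq[Expr]
  // Default: union the children's references, as the new base class does.
  def references: Set[Long] = children.flatMap(_.references).toSet
}
case class Ref(exprId: Long) extends Expr {
  def children = Nil
  override def references = Set(exprId)  // leaves contribute themselves
}
case class Add(l: Expr, r: Expr) extends Expr {
  def children = Seq(l, r)               // no override needed
}

Add(Ref(1), Add(Ref(2), Ref(1))).references  // Set(1, 2)
```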
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
index 38f836f0a1..851db95b91 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.types.DoubleType
case object Rand extends LeafExpression {
override def dataType = DoubleType
override def nullable = false
- override def references = Set.empty
private[this] lazy val rand = new Random
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
index 95633dd0c9..63ac2a608b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
@@ -24,7 +24,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
type EvaluatedType = Any
- def references = children.flatMap(_.references).toSet
def nullable = true
/** This method has been generated by this script
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index d2b7685e73..d00b2ac097 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -31,7 +31,6 @@ case object Descending extends SortDirection
case class SortOrder(child: Expression, direction: SortDirection) extends Expression
with trees.UnaryNode[Expression] {
- override def references = child.references
override def dataType = child.dataType
override def nullable = child.nullable
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
index eb8898900d..1eb5571579 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
@@ -35,7 +35,7 @@ case class WrapDynamic(children: Seq[Attribute]) extends Expression {
type EvaluatedType = DynamicRow
def nullable = false
- def references = children.toSet
+
def dataType = DynamicType
override def eval(input: Row): DynamicRow = input match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 613b87ca98..dbc0c2965a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -78,7 +78,7 @@ abstract class AggregateFunction
/** Base should return the generic aggregate expression that this function is computing */
val base: AggregateExpression
- override def references = base.references
+
override def nullable = base.nullable
override def dataType = base.dataType
@@ -89,7 +89,7 @@ abstract class AggregateFunction
}
case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def references = child.references
+
override def nullable = true
override def dataType = child.dataType
override def toString = s"MIN($child)"
@@ -119,7 +119,7 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr
}
case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def references = child.references
+
override def nullable = true
override def dataType = child.dataType
override def toString = s"MAX($child)"
@@ -149,7 +149,7 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr
}
case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def references = child.references
+
override def nullable = false
override def dataType = LongType
override def toString = s"COUNT($child)"
@@ -166,7 +166,7 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate
def this() = this(null)
override def children = expressions
- override def references = expressions.flatMap(_.references).toSet
+
override def nullable = false
override def dataType = LongType
override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})"
@@ -184,7 +184,6 @@ case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpress
def this() = this(null)
override def children = expressions
- override def references = expressions.flatMap(_.references).toSet
override def nullable = false
override def dataType = ArrayType(expressions.head.dataType)
override def toString = s"AddToHashSet(${expressions.mkString(",")})"
@@ -219,7 +218,6 @@ case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression
def this() = this(null)
override def children = inputSet :: Nil
- override def references = inputSet.references
override def nullable = false
override def dataType = LongType
override def toString = s"CombineAndCount($inputSet)"
@@ -248,7 +246,7 @@ case class CombineSetsAndCountFunction(
case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
extends AggregateExpression with trees.UnaryNode[Expression] {
- override def references = child.references
+
override def nullable = false
override def dataType = child.dataType
override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
@@ -257,7 +255,7 @@ case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
extends AggregateExpression with trees.UnaryNode[Expression] {
- override def references = child.references
+
override def nullable = false
override def dataType = LongType
override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
@@ -266,7 +264,7 @@ case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
extends PartialAggregate with trees.UnaryNode[Expression] {
- override def references = child.references
+
override def nullable = false
override def dataType = LongType
override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
@@ -284,7 +282,7 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
}
case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def references = child.references
+
override def nullable = false
override def dataType = DoubleType
override def toString = s"AVG($child)"
@@ -304,7 +302,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
}
case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def references = child.references
+
override def nullable = false
override def dataType = child.dataType
override def toString = s"SUM($child)"
@@ -322,7 +320,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
case class SumDistinct(child: Expression)
extends AggregateExpression with trees.UnaryNode[Expression] {
- override def references = child.references
+
override def nullable = false
override def dataType = child.dataType
override def toString = s"SUM(DISTINCT $child)"
@@ -331,7 +329,6 @@ case class SumDistinct(child: Expression)
}
case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def references = child.references
override def nullable = true
override def dataType = child.dataType
override def toString = s"FIRST($child)"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 5f8b6ae10f..aae86a3628 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -95,8 +95,6 @@ case class MaxOf(left: Expression, right: Expression) extends Expression {
override def children = left :: right :: Nil
- override def references = left.references ++ right.references
-
override def dataType = left.dataType
override def eval(input: Row): Any = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
index c1154eb81c..dafd745ec9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
@@ -31,7 +31,7 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
/** `Null` is returned for invalid ordinals. */
override def nullable = true
override def foldable = child.foldable && ordinal.foldable
- override def references = children.flatMap(_.references).toSet
+
def dataType = child.dataType match {
case ArrayType(dt, _) => dt
case MapType(_, vt, _) => vt
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index e99c5b452d..9c865254e0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -47,8 +47,6 @@ abstract class Generator extends Expression {
override def nullable = false
- override def references = children.flatMap(_.references).toSet
-
/**
* Should be overridden by specific generators. Called only once for each instance to ensure
* that rule application does not change the output schema of a generator.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index e15e16d633..a8c2396d62 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -52,7 +52,7 @@ case class Literal(value: Any, dataType: DataType) extends LeafExpression {
override def foldable = true
def nullable = value == null
- def references = Set.empty
+
override def toString = if (value != null) value.toString else "null"
@@ -66,8 +66,6 @@ case class MutableLiteral(var value: Any, nullable: Boolean = true) extends Leaf
val dataType = Literal(value).dataType
- def references = Set.empty
-
def update(expression: Expression, input: Row) = {
value = expression.eval(input)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 02d0476262..7c4b9d4847 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -62,7 +62,7 @@ abstract class Attribute extends NamedExpression {
def toAttribute = this
def newInstance: Attribute
- override def references = Set(this)
+
}
/**
@@ -85,7 +85,7 @@ case class Alias(child: Expression, name: String)
override def dataType = child.dataType
override def nullable = child.nullable
- override def references = child.references
+
override def toAttribute = {
if (resolved) {
@@ -116,6 +116,8 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
(val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil)
extends Attribute with trees.LeafNode[Expression] {
+ override def references = AttributeSet(this :: Nil)
+
override def equals(other: Any) = other match {
case ar: AttributeReference => exprId == ar.exprId && dataType == ar.dataType
case _ => false
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
index e88c5d4fa1..086d0a3e07 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
@@ -26,7 +26,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
/** Coalesce is nullable if all of its children are nullable, or if it has no children. */
def nullable = !children.exists(!_.nullable)
- def references = children.flatMap(_.references).toSet
// Coalesce is foldable if all children are foldable.
override def foldable = !children.exists(!_.foldable)
@@ -53,7 +52,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
}
case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
- def references = child.references
override def foldable = child.foldable
def nullable = false
@@ -65,7 +63,6 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr
}
case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
- def references = child.references
override def foldable = child.foldable
def nullable = false
override def toString = s"IS NOT NULL $child"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 5976b0ddf3..1313ccd120 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -85,7 +85,7 @@ case class Not(child: Expression) extends UnaryExpression with Predicate {
*/
case class In(value: Expression, list: Seq[Expression]) extends Predicate {
def children = value +: list
- def references = children.flatMap(_.references).toSet
+
def nullable = true // TODO: Figure out correct nullability semantics of IN.
override def toString = s"$value IN ${list.mkString("(", ",", ")")}"
@@ -197,7 +197,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
def children = predicate :: trueValue :: falseValue :: Nil
override def nullable = trueValue.nullable || falseValue.nullable
- def references = children.flatMap(_.references).toSet
+
override lazy val resolved = childrenResolved && trueValue.dataType == falseValue.dataType
def dataType = {
if (!resolved) {
@@ -239,7 +239,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
case class CaseWhen(branches: Seq[Expression]) extends Expression {
type EvaluatedType = Any
def children = branches
- def references = children.flatMap(_.references).toSet
+
def dataType = {
if (!resolved) {
throw new UnresolvedException(this, "cannot resolve due to differing types in some branches")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
index e6c570b47b..3d4c4a8853 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
@@ -26,8 +26,6 @@ import org.apache.spark.util.collection.OpenHashSet
case class NewSet(elementType: DataType) extends LeafExpression {
type EvaluatedType = Any
- def references = Set.empty
-
def nullable = false
// We are currently only using these Expressions internally for aggregation. However, if we ever
@@ -53,9 +51,6 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
def nullable = set.nullable
def dataType = set.dataType
-
- def references = (item.flatMap(_.references) ++ set.flatMap(_.references)).toSet
-
def eval(input: Row): Any = {
val itemEval = item.eval(input)
val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 97fc3a3b14..c2a3a5ca3c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -226,8 +226,6 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
if (str.dataType == BinaryType) str.dataType else StringType
}
- def references = children.flatMap(_.references).toSet
-
override def children = str :: pos :: len :: Nil
@inline
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 5f86d6047c..ddd4b3755d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -65,8 +65,10 @@ object ColumnPruning extends Rule[LogicalPlan] {
// Eliminate unneeded attributes from either side of a Join.
case Project(projectList, Join(left, right, joinType, condition)) =>
// Collect the list of all references required either above or to evaluate the condition.
- val allReferences: Set[Attribute] =
- projectList.flatMap(_.references).toSet ++ condition.map(_.references).getOrElse(Set.empty)
+ val allReferences: AttributeSet =
+ AttributeSet(
+ projectList.flatMap(_.references.iterator)) ++
+ condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
/** Applies a projection only when the child is producing unnecessary attributes */
def pruneJoinChild(c: LogicalPlan) = prunedChild(c, allReferences)
@@ -76,8 +78,8 @@ object ColumnPruning extends Rule[LogicalPlan] {
// Eliminate unneeded attributes from right side of a LeftSemiJoin.
case Join(left, right, LeftSemi, condition) =>
// Collect the list of all references required to evaluate the condition.
- val allReferences: Set[Attribute] =
- condition.map(_.references).getOrElse(Set.empty)
+ val allReferences: AttributeSet =
+ condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
Join(left, prunedChild(right, allReferences), LeftSemi, condition)
@@ -104,7 +106,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
}
/** Applies a projection only when the child is producing unnecessary attributes */
- private def prunedChild(c: LogicalPlan, allReferences: Set[Attribute]) =
+ private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) {
Project(allReferences.filter(c.outputSet.contains).toSeq, c)
} else {
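For intuition, `prunedChild` inserts a `Project` only when the child produces attributes that nothing above it references. A hedged sketch with plain `Set`s and names standing in for `AttributeSet`s and exprIds:

```scala
val childOutput   = Set("a", "b", "c")                 // c.outputSet
val allReferences = Set("a", "x")                      // needed above the child
val kept = allReferences.filter(childOutput.contains)  // Set("a")
val unnecessary = childOutput -- kept                  // Set("b", "c"): prune them
```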
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 0988b0c6d9..1e177e28f8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.plans
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression}
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.types.{ArrayType, DataType, StructField, StructType}
@@ -29,7 +29,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
/**
* Returns the set of attributes that are output by this node.
*/
- def outputSet: Set[Attribute] = output.toSet
+ def outputSet: AttributeSet = AttributeSet(output)
/**
* Runs [[transform]] with `rule` on all expressions present in this query operator.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 278569f0cb..8616ac45b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -46,16 +46,10 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
)
/**
- * Returns the set of attributes that are referenced by this node
- * during evaluation.
- */
- def references: Set[Attribute]
-
- /**
* Returns the set of attributes that this node takes as
* input from its children.
*/
- lazy val inputSet: Set[Attribute] = children.flatMap(_.output).toSet
+ lazy val inputSet: AttributeSet = AttributeSet(children.flatMap(_.output))
/**
* Returns true if this expression and all its children have been resolved to a specific schema
@@ -126,9 +120,6 @@ abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] {
override lazy val statistics: Statistics =
throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.")
-
- // Leaf nodes by definition cannot reference any input attributes.
- override def references = Set.empty
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
index d3f9d0fb93..4460c86ed9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
@@ -30,6 +30,4 @@ case class ScriptTransformation(
input: Seq[Expression],
script: String,
output: Seq[Attribute],
- child: LogicalPlan) extends UnaryNode {
- def references = input.flatMap(_.references).toSet
-}
+ child: LogicalPlan) extends UnaryNode
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 3cb407217c..4adfb18937 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.types._
case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
def output = projectList.map(_.toAttribute)
- def references = projectList.flatMap(_.references).toSet
}
/**
@@ -59,14 +58,10 @@ case class Generate(
override def output =
if (join) child.output ++ generatorOutput else generatorOutput
-
- override def references =
- if (join) child.outputSet else generator.references
}
case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
override def output = child.output
- override def references = condition.references
}
case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
@@ -76,8 +71,6 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
override lazy val resolved =
childrenResolved &&
!left.output.zip(right.output).exists { case (l,r) => l.dataType != r.dataType }
-
- override def references = Set.empty
}
case class Join(
@@ -86,8 +79,6 @@ case class Join(
joinType: JoinType,
condition: Option[Expression]) extends BinaryNode {
- override def references = condition.map(_.references).getOrElse(Set.empty)
-
override def output = {
joinType match {
case LeftSemi =>
@@ -106,8 +97,6 @@ case class Join(
case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
def output = left.output
-
- def references = Set.empty
}
case class InsertIntoTable(
@@ -118,7 +107,6 @@ case class InsertIntoTable(
extends LogicalPlan {
// The table being inserted into is a child for the purposes of transformations.
override def children = table :: child :: Nil
- override def references = Set.empty
override def output = child.output
override lazy val resolved = childrenResolved && child.output.zip(table.output).forall {
@@ -130,20 +118,17 @@ case class InsertIntoCreatedTable(
databaseName: Option[String],
tableName: String,
child: LogicalPlan) extends UnaryNode {
- override def references = Set.empty
override def output = child.output
}
case class WriteToFile(
path: String,
child: LogicalPlan) extends UnaryNode {
- override def references = Set.empty
override def output = child.output
}
case class Sort(order: Seq[SortOrder], child: LogicalPlan) extends UnaryNode {
override def output = child.output
- override def references = order.flatMap(_.references).toSet
}
case class Aggregate(
@@ -152,19 +137,20 @@ case class Aggregate(
child: LogicalPlan)
extends UnaryNode {
+ /** The set of all AttributeReferences required for this aggregation. */
+ def references =
+ AttributeSet(
+ groupingExpressions.flatMap(_.references) ++ aggregateExpressions.flatMap(_.references))
+
override def output = aggregateExpressions.map(_.toAttribute)
- override def references =
- (groupingExpressions ++ aggregateExpressions).flatMap(_.references).toSet
}
case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
override def output = child.output
- override def references = limitExpr.references
}
case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
override def output = child.output.map(_.withQualifiers(alias :: Nil))
- override def references = Set.empty
}
/**
@@ -191,20 +177,16 @@ case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode {
a.qualifiers)
case other => other
}
-
- override def references = Set.empty
}
case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan)
extends UnaryNode {
override def output = child.output
- override def references = Set.empty
}
case class Distinct(child: LogicalPlan) extends UnaryNode {
override def output = child.output
- override def references = child.outputSet
}
case object NoRelation extends LeafNode {
@@ -213,5 +195,4 @@ case object NoRelation extends LeafNode {
case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
override def output = left.output
- override def references = Set.empty
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
index 7146fbd540..72b0c5c8e7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
@@ -31,13 +31,9 @@ abstract class RedistributeData extends UnaryNode {
case class SortPartitions(sortExpressions: Seq[SortOrder], child: LogicalPlan)
extends RedistributeData {
-
- def references = sortExpressions.flatMap(_.references).toSet
}
case class Repartition(partitionExpressions: Seq[Expression], child: LogicalPlan)
extends RedistributeData {
-
- def references = partitionExpressions.flatMap(_.references).toSet
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 4bb022cf23..ccb0df113c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -71,6 +71,7 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
"An AllTuples should be used to represent a distribution that only has " +
"a single partition.")
+ // TODO: This is not really valid...
def clustering = ordering.map(_.child).toSet
}
@@ -139,7 +140,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
with Partitioning {
override def children = expressions
- override def references = expressions.flatMap(_.references).toSet
override def nullable = false
override def dataType = IntegerType
@@ -179,7 +179,6 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
with Partitioning {
override def children = ordering
- override def references = ordering.flatMap(_.references).toSet
override def nullable = false
override def dataType = IntegerType
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 6344874538..296202543e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.types.{StringType, NullType}
case class Dummy(optKey: Option[Expression]) extends Expression {
def children = optKey.toSeq
- def references = Set.empty[Attribute]
def nullable = true
def dataType = NullType
override lazy val resolved = true
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 8a9f4deb6a..6f0eed3f63 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -344,8 +344,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
prunePushedDownFilters: Seq[Expression] => Seq[Expression],
scanBuilder: Seq[Attribute] => SparkPlan): SparkPlan = {
- val projectSet = projectList.flatMap(_.references).toSet
- val filterSet = filterPredicates.flatMap(_.references).toSet
+ val projectSet = AttributeSet(projectList.flatMap(_.references))
+ val filterSet = AttributeSet(filterPredicates.flatMap(_.references))
val filterCondition = prunePushedDownFilters(filterPredicates).reduceLeftOption(And)
// Right now we still use a projection even if the only evaluation is applying an alias
@@ -354,7 +354,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
// TODO: Decouple final output schema from expression evaluation so this copy can be
// avoided safely.
- if (projectList.toSet == projectSet && filterSet.subsetOf(projectSet)) {
+ if (AttributeSet(projectList.map(_.toAttribute)) == projectSet &&
+ filterSet.subsetOf(projectSet)) {
// When it is possible to just use column pruning to get the right projection and
// when the columns of this projection are enough to evaluate all filter conditions,
// just do a scan followed by a filter, with no extra project.
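The fixed condition now compares like with like: both sides are `AttributeSet`s keyed by exprId, where the old code compared a set of `NamedExpression`s (possibly `Alias`es) against a set of `Attribute`s. A sketch of the fast path with stand-in exprIds:

```scala
val projectSet = Set(1L, 2L)  // exprIds referenced by projectList
val projectIds = Set(1L, 2L)  // exprIds of projectList.map(_.toAttribute)
val filterSet  = Set(2L)      // exprIds referenced by filterPredicates
// Skip the extra Project exactly when projection is pure column pruning and
// those columns suffice to evaluate every remaining filter.
val scanOnly = projectIds == projectSet && filterSet.subsetOf(projectSet)  // true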
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index e63b490304..24e88eea31 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -79,8 +79,6 @@ private[sql] case class InMemoryRelation(
override def children = Seq.empty
- override def references = Set.empty
-
override def newInstance() = {
new InMemoryRelation(
output.map(_.newInstance),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 21cbbc9772..7d33ea5b02 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -141,10 +141,9 @@ case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQ
extends LogicalPlan with MultiInstanceRelation {
def output = alreadyPlanned.output
- override def references = Set.empty
override def children = Nil
- override final def newInstance: this.type = {
+ override final def newInstance(): this.type = {
SparkLogicalPlan(
alreadyPlanned match {
case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index f31df05182..5b896c55b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -58,8 +58,6 @@ package object debug {
}
private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode {
- def references = Set.empty
-
def output = child.output
implicit object SetAccumulatorParam extends AccumulatorParam[HashSet[String]] {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index b92091b560..aef6ebf86b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -49,7 +49,6 @@ private[spark] case class PythonUDF(
override def toString = s"PythonUDF#$name(${children.mkString(",")})"
def nullable: Boolean = true
- def references: Set[Attribute] = children.flatMap(_.references).toSet
override def eval(input: Row) = sys.error("PythonUDFs can not be directly evaluated.")
}
@@ -113,7 +112,6 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] {
case class EvaluatePython(udf: PythonUDF, child: LogicalPlan) extends logical.UnaryNode {
val resultAttribute = AttributeReference("pythonUDF", udf.dataType, nullable=true)()
- def references = Set.empty
def output = child.output :+ resultAttribute
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 389ace726d..10fa8314c9 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -79,9 +79,9 @@ private[hive] trait HiveStrategies {
hiveContext.convertMetastoreParquet =>
// Filter out all predicates that only deal with partition keys
- val partitionKeyIds = relation.partitionKeys.map(_.exprId).toSet
+ val partitionKeys = AttributeSet(relation.partitionKeys)
val (pruningPredicates, otherPredicates) = predicates.partition {
- _.references.map(_.exprId).subsetOf(partitionKeyIds)
+ _.references.subsetOf(partitionKeys)
}
// We are going to throw the predicates and projection back at the whole optimization
@@ -176,9 +176,9 @@ private[hive] trait HiveStrategies {
case PhysicalOperation(projectList, predicates, relation: MetastoreRelation) =>
// Filter out all predicates that only deal with partition keys, these are given to the
// hive table scan operator to be used for partition pruning.
- val partitionKeyIds = relation.partitionKeys.map(_.exprId).toSet
+ val partitionKeyIds = AttributeSet(relation.partitionKeys)
val (pruningPredicates, otherPredicates) = predicates.partition {
- _.references.map(_.exprId).subsetOf(partitionKeyIds)
+ _.references.subsetOf(partitionKeyIds)
}
pruneFilterProject(
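Partition pruning now reads as a direct subset test on reference sets: a predicate can prune partitions iff every attribute it references is a partition key. A sketch with stand-in exprId sets in place of `_.references`:

```scala
val partitionKeys = Set(10L, 11L)                 // exprIds of partition columns
val predicates    = Seq(Set(10L), Set(10L, 42L))  // references of two predicates
val (pruning, other) = predicates.partition(_.subsetOf(partitionKeys))
// pruning == Seq(Set(10L)); other == Seq(Set(10L, 42L))
```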
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index c6497a15ef..7d1ad53d8b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -88,7 +88,6 @@ private[hive] abstract class HiveUdf extends Expression with Logging with HiveFu
type EvaluatedType = Any
def nullable = true
- def references = children.flatMap(_.references).toSet
lazy val function = createFunction[UDFType]()
@@ -229,8 +228,6 @@ private[hive] case class HiveGenericUdaf(
def nullable: Boolean = true
- def references: Set[Attribute] = children.map(_.references).flatten.toSet
-
override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
def newInstance() = new HiveUdafFunction(functionClassName, children, this)
@@ -253,8 +250,6 @@ private[hive] case class HiveGenericUdtf(
children: Seq[Expression])
extends Generator with HiveInspectors with HiveFunctionFactory {
- override def references = children.flatMap(_.references).toSet
-
@transient
protected lazy val function: GenericUDTF = createFunction()
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
index 6b3ffd1c0f..b6be6bc1bf 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.hive.execution
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
-case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested])
case class Nested(a: Int, B: Int)
+case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested])
/**
* A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution.
@@ -57,6 +57,13 @@ class HiveResolutionSuite extends HiveComparisonTest {
.registerTempTable("caseSensitivityTest")
sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest")
+
+ println(sql("SELECT * FROM casesensitivitytest one JOIN casesensitivitytest two ON one.a = two.a").queryExecution)
+
+ sql("SELECT * FROM casesensitivitytest one JOIN casesensitivitytest two ON one.a = two.a").collect()
+
+ // TODO: sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a")
+
}
test("nested repeated resolution") {