-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala | 1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala | 106
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala | 6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala | 1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala | 1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala | 1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala | 25
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala | 6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala | 6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala | 5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala | 11
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala | 29
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala | 3
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala | 1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala | 2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala | 8
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala | 5
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala | 9
35 files changed, 166 insertions(+), 123 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index c18d7858f0..4a95240741 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -132,7 +132,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Boolean)
case s @ Sort(ordering, p @ Project(projectList, child)) if !s.resolved && p.resolved =>
val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
val resolved = unresolved.flatMap(child.resolveChildren)
- val requiredAttributes = resolved.collect { case a: Attribute => a }.toSet
+ val requiredAttributes = AttributeSet(resolved.collect { case a: Attribute => a })
val missingInProject = requiredAttributes -- p.output
if (missingInProject.nonEmpty) {
@@ -152,8 +152,8 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Boolean)
)
logDebug(s"Grouping expressions: $groupingRelation")
- val resolved = unresolved.flatMap(groupingRelation.resolve).toSet
- val missingInAggs = resolved -- a.outputSet
+ val resolved = unresolved.flatMap(groupingRelation.resolve)
+ val missingInAggs = resolved.filterNot(a.outputSet.contains)
logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs")
if (missingInAggs.nonEmpty) {
// Add missing grouping exprs and then project them away after the sort.
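
A note on the rule being patched here: when a Sort references a column that the
Project beneath it dropped, the analyzer temporarily widens the projection, sorts,
and then projects the extra column away again. A minimal standalone sketch of that
idea, using a toy Row type rather than the Catalyst API:

```scala
object SortTrick {
  case class Row(values: Map[String, Int])

  // keep = the columns the query actually selects; sortKey may not be among them.
  def sortThenProject(rows: Seq[Row], keep: Set[String], sortKey: String): Seq[Map[String, Int]] = {
    // 1. Widen the projection so the sort key survives.
    val widened = rows.map(_.values.filter { case (k, _) => keep(k) || k == sortKey })
    // 2. Sort on the (possibly temporary) key.
    val sorted = widened.sortBy(_(sortKey))
    // 3. Project the temporary key away again.
    sorted.map(_.filter { case (k, _) => keep(k) })
  }
}
```
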
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index a0e25775da..a2c61c6548 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -66,7 +66,6 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression
override def dataType = throw new UnresolvedException(this, "dataType")
override def foldable = throw new UnresolvedException(this, "foldable")
override def nullable = throw new UnresolvedException(this, "nullable")
- override def references = children.flatMap(_.references).toSet
override lazy val resolved = false
// Unresolved functions are transient at compile time and don't get evaluated during execution.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
new file mode 100644
index 0000000000..c3a08bbdb6
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+protected class AttributeEquals(val a: Attribute) {
+ override def hashCode() = a.exprId.hashCode()
+ override def equals(other: Any) = other match {
+ case otherReference: AttributeEquals => a.exprId == otherReference.a.exprId
+ case _ => false
+ }
+}
+
+object AttributeSet {
+ /** Constructs a new [[AttributeSet]] given a sequence of [[Attribute Attributes]]. */
+ def apply(baseSet: Seq[Attribute]) = {
+ new AttributeSet(baseSet.map(new AttributeEquals(_)).toSet)
+ }
+}
+
+/**
+ * A Set designed to hold [[AttributeReference]] objects that performs equality checking using
+ * expression id instead of standard java equality. Using expression id means that these
+ * sets will correctly test for membership, even when the AttributeReferences in question differ
+ * cosmetically (e.g., the names have different capitalizations).
+ *
+ * Note that we do not override equality for Attribute references as it is really weird when
+ * `AttributeReference("a", ...) == AttributeReference("b", ...)`. This tactic leads to broken tests,
+ * and also makes doing transformations hard (we always try to keep older trees instead of new ones
+ * when the transformation was a no-op).
+ */
+class AttributeSet private (val baseSet: Set[AttributeEquals])
+ extends Traversable[Attribute] with Serializable {
+
+ /** Returns true if the members of this AttributeSet and other are the same. */
+ override def equals(other: Any) = other match {
+ case otherSet: AttributeSet => otherSet.size == baseSet.size && baseSet.map(_.a).forall(otherSet.contains)
+ case _ => false
+ }
+
+ /** Returns true if this set contains an Attribute with the same expression id as `elem` */
+ def contains(elem: NamedExpression): Boolean =
+ baseSet.contains(new AttributeEquals(elem.toAttribute))
+
+ /** Returns a new [[AttributeSet]] that contains `elem` in addition to the current elements. */
+ def +(elem: Attribute): AttributeSet = // scalastyle:ignore
+ new AttributeSet(baseSet + new AttributeEquals(elem))
+
+ /** Returns a new [[AttributeSet]] that does not contain `elem`. */
+ def -(elem: Attribute): AttributeSet =
+ new AttributeSet(baseSet - new AttributeEquals(elem))
+
+ /** Returns an iterator containing all of the attributes in the set. */
+ def iterator: Iterator[Attribute] = baseSet.map(_.a).iterator
+
+ /**
+ * Returns true if the [[Attribute Attributes]] in this set are a subset of the Attributes in
+ * `other`.
+ */
+ def subsetOf(other: AttributeSet) = baseSet.subsetOf(other.baseSet)
+
+ /**
+ * Returns a new [[AttributeSet]] that does not contain any of the [[Attribute Attributes]] found
+ * in `other`.
+ */
+ def --(other: Traversable[NamedExpression]) =
+ new AttributeSet(baseSet -- other.map(a => new AttributeEquals(a.toAttribute)))
+
+ /**
+ * Returns a new [[AttributeSet]] that contains all of the [[Attribute Attributes]] found
+ * in `other`.
+ */
+ def ++(other: AttributeSet) = new AttributeSet(baseSet ++ other.baseSet)
+
+ /**
+ * Returns a new [[AttributeSet]] containing only the [[Attribute Attributes]] where `f` evaluates to
+ * true.
+ */
+ override def filter(f: Attribute => Boolean) = new AttributeSet(baseSet.filter(ae => f(ae.a)))
+
+ /**
+ * Returns a new [[AttributeSet]] that only contains [[Attribute Attributes]] that are found in
+ * `this` and `other`.
+ */
+ def intersect(other: AttributeSet) = new AttributeSet(baseSet.intersect(other.baseSet))
+
+ override def foreach[U](f: (Attribute) => U): Unit = baseSet.map(_.a).foreach(f)
+
+ // We must force toSeq to be strict, otherwise we end up with a [[Stream]] that captures all
+ // sorts of things in its closure.
+ override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq
+}
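
The new AttributeSet above is the heart of this change: membership is decided by
expression id alone, so two references to the same underlying column compare equal
even when their names differ cosmetically. A minimal standalone sketch of the
technique — MyAttr, AttrById, and IdSet are hypothetical stand-ins for
AttributeReference, AttributeEquals, and AttributeSet, not the real classes:

```scala
case class MyAttr(name: String, exprId: Long)

// Wrapper that redefines equality to look only at the expression id.
class AttrById(val a: MyAttr) {
  override def hashCode(): Int = a.exprId.hashCode()
  override def equals(other: Any): Boolean = other match {
    case that: AttrById => a.exprId == that.a.exprId
    case _ => false
  }
}

class IdSet private (baseSet: Set[AttrById]) {
  def contains(a: MyAttr): Boolean = baseSet.contains(new AttrById(a))
  def --(other: Iterable[MyAttr]): IdSet = new IdSet(baseSet -- other.map(new AttrById(_)))
  def isEmpty: Boolean = baseSet.isEmpty
}

object IdSet {
  def apply(attrs: Seq[MyAttr]): IdSet = new IdSet(attrs.map(new AttrById(_)).toSet)
}

object Demo extends App {
  val a1 = MyAttr("name", exprId = 1) // resolved against the same column...
  val a2 = MyAttr("NAME", exprId = 1) // ...but written with different casing
  assert(!Set(a1).contains(a2))       // a plain Set compares names and misses
  assert(IdSet(Seq(a1)).contains(a2)) // the id-based set matches correctly
}
```
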
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 0913f15888..54c6baf1af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -32,8 +32,6 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
type EvaluatedType = Any
- override def references = Set.empty
-
override def toString = s"input[$ordinal]"
override def eval(input: Row): Any = input(ordinal)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index ba62dabe3d..70507e7ee2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -41,7 +41,7 @@ abstract class Expression extends TreeNode[Expression] {
*/
def foldable: Boolean = false
def nullable: Boolean
- def references: Set[Attribute]
+ def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator))
/** Returns the result of evaluating this expression on a given input Row */
def eval(input: Row = null): EvaluatedType
@@ -230,8 +230,6 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression]
override def foldable = left.foldable && right.foldable
- override def references = left.references ++ right.references
-
override def toString = s"($left $symbol $right)"
}
@@ -242,5 +240,5 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]
abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
self: Product =>
- override def references = child.references
+
}
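
The Expression change above is what lets every file below delete its hand-written
`references`: the base class now derives it from the children, and only genuine
leaves such as AttributeReference override it. A sketch of the pattern with toy
types — Expr, Attr, and Add are illustrative, and the real default returns an
AttributeSet rather than Set[String]:

```scala
abstract class Expr {
  def children: Seq[Expr]
  // Default: an expression references whatever its children reference.
  def references: Set[String] = children.flatMap(_.references).toSet
}

case class Attr(name: String) extends Expr {
  def children: Seq[Expr] = Nil
  override def references: Set[String] = Set(name) // the one real leaf case
}

case class Add(left: Expr, right: Expr) extends Expr {
  def children: Seq[Expr] = Seq(left, right) // no hand-written references needed
}

object RefDemo extends App {
  assert(Add(Attr("a"), Add(Attr("b"), Attr("a"))).references == Set("a", "b"))
}
```
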
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
index 38f836f0a1..851db95b91 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.types.DoubleType
case object Rand extends LeafExpression {
override def dataType = DoubleType
override def nullable = false
- override def references = Set.empty
private[this] lazy val rand = new Random
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
index 95633dd0c9..63ac2a608b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
@@ -24,7 +24,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression])
type EvaluatedType = Any
- def references = children.flatMap(_.references).toSet
def nullable = true
/** This method has been generated by this script
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index d2b7685e73..d00b2ac097 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -31,7 +31,6 @@ case object Descending extends SortDirection
case class SortOrder(child: Expression, direction: SortDirection) extends Expression
with trees.UnaryNode[Expression] {
- override def references = child.references
override def dataType = child.dataType
override def nullable = child.nullable
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
index eb8898900d..1eb5571579 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
@@ -35,7 +35,7 @@ case class WrapDynamic(children: Seq[Attribute]) extends Expression {
type EvaluatedType = DynamicRow
def nullable = false
- def references = children.toSet
+
def dataType = DynamicType
override def eval(input: Row): DynamicRow = input match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 613b87ca98..dbc0c2965a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -78,7 +78,7 @@ abstract class AggregateFunction
/** Base should return the generic aggregate expression that this function is computing */
val base: AggregateExpression
- override def references = base.references
+
override def nullable = base.nullable
override def dataType = base.dataType
@@ -89,7 +89,7 @@ abstract class AggregateFunction
}
case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def references = child.references
+
override def nullable = true
override def dataType = child.dataType
override def toString = s"MIN($child)"
@@ -119,7 +119,7 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction
}
case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def references = child.references
+
override def nullable = true
override def dataType = child.dataType
override def toString = s"MAX($child)"
@@ -149,7 +149,7 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction
}
case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def references = child.references
+
override def nullable = false
override def dataType = LongType
override def toString = s"COUNT($child)"
@@ -166,7 +166,7 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate
def this() = this(null)
override def children = expressions
- override def references = expressions.flatMap(_.references).toSet
+
override def nullable = false
override def dataType = LongType
override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})"
@@ -184,7 +184,6 @@ case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression
def this() = this(null)
override def children = expressions
- override def references = expressions.flatMap(_.references).toSet
override def nullable = false
override def dataType = ArrayType(expressions.head.dataType)
override def toString = s"AddToHashSet(${expressions.mkString(",")})"
@@ -219,7 +218,6 @@ case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression
def this() = this(null)
override def children = inputSet :: Nil
- override def references = inputSet.references
override def nullable = false
override def dataType = LongType
override def toString = s"CombineAndCount($inputSet)"
@@ -248,7 +246,7 @@ case class CombineSetsAndCountFunction(
case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
extends AggregateExpression with trees.UnaryNode[Expression] {
- override def references = child.references
+
override def nullable = false
override def dataType = child.dataType
override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
@@ -257,7 +255,7 @@ case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
extends AggregateExpression with trees.UnaryNode[Expression] {
- override def references = child.references
+
override def nullable = false
override def dataType = LongType
override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
@@ -266,7 +264,7 @@ case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
extends PartialAggregate with trees.UnaryNode[Expression] {
- override def references = child.references
+
override def nullable = false
override def dataType = LongType
override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
@@ -284,7 +282,7 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
}
case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def references = child.references
+
override def nullable = false
override def dataType = DoubleType
override def toString = s"AVG($child)"
@@ -304,7 +302,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression]
}
case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def references = child.references
+
override def nullable = false
override def dataType = child.dataType
override def toString = s"SUM($child)"
@@ -322,7 +320,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression]
case class SumDistinct(child: Expression)
extends AggregateExpression with trees.UnaryNode[Expression] {
- override def references = child.references
+
override def nullable = false
override def dataType = child.dataType
override def toString = s"SUM(DISTINCT $child)"
@@ -331,7 +329,6 @@ case class SumDistinct(child: Expression)
}
case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def references = child.references
override def nullable = true
override def dataType = child.dataType
override def toString = s"FIRST($child)"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 5f8b6ae10f..aae86a3628 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -95,8 +95,6 @@ case class MaxOf(left: Expression, right: Expression) extends Expression {
override def children = left :: right :: Nil
- override def references = left.references ++ right.references
-
override def dataType = left.dataType
override def eval(input: Row): Any = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
index c1154eb81c..dafd745ec9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
@@ -31,7 +31,7 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
/** `Null` is returned for invalid ordinals. */
override def nullable = true
override def foldable = child.foldable && ordinal.foldable
- override def references = children.flatMap(_.references).toSet
+
def dataType = child.dataType match {
case ArrayType(dt, _) => dt
case MapType(_, vt, _) => vt
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index e99c5b452d..9c865254e0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -47,8 +47,6 @@ abstract class Generator extends Expression {
override def nullable = false
- override def references = children.flatMap(_.references).toSet
-
/**
* Should be overridden by specific generators. Called only once for each instance to ensure
* that rule application does not change the output schema of a generator.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index e15e16d633..a8c2396d62 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -52,7 +52,7 @@ case class Literal(value: Any, dataType: DataType) extends LeafExpression {
override def foldable = true
def nullable = value == null
- def references = Set.empty
+
override def toString = if (value != null) value.toString else "null"
@@ -66,8 +66,6 @@ case class MutableLiteral(var value: Any, nullable: Boolean = true) extends LeafExpression
val dataType = Literal(value).dataType
- def references = Set.empty
-
def update(expression: Expression, input: Row) = {
value = expression.eval(input)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 02d0476262..7c4b9d4847 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -62,7 +62,7 @@ abstract class Attribute extends NamedExpression {
def toAttribute = this
def newInstance: Attribute
- override def references = Set(this)
+
}
/**
@@ -85,7 +85,7 @@ case class Alias(child: Expression, name: String)
override def dataType = child.dataType
override def nullable = child.nullable
- override def references = child.references
+
override def toAttribute = {
if (resolved) {
@@ -116,6 +116,8 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolean = true)
(val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil)
extends Attribute with trees.LeafNode[Expression] {
+ override def references = AttributeSet(this :: Nil)
+
override def equals(other: Any) = other match {
case ar: AttributeReference => exprId == ar.exprId && dataType == ar.dataType
case _ => false
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
index e88c5d4fa1..086d0a3e07 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
@@ -26,7 +26,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
/** Coalesce is nullable if all of its children are nullable, or if it has no children. */
def nullable = !children.exists(!_.nullable)
- def references = children.flatMap(_.references).toSet
// Coalesce is foldable if all children are foldable.
override def foldable = !children.exists(!_.foldable)
@@ -53,7 +52,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
}
case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
- def references = child.references
override def foldable = child.foldable
def nullable = false
@@ -65,7 +63,6 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression]
}
case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
- def references = child.references
override def foldable = child.foldable
def nullable = false
override def toString = s"IS NOT NULL $child"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 5976b0ddf3..1313ccd120 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -85,7 +85,7 @@ case class Not(child: Expression) extends UnaryExpression with Predicate {
*/
case class In(value: Expression, list: Seq[Expression]) extends Predicate {
def children = value +: list
- def references = children.flatMap(_.references).toSet
+
def nullable = true // TODO: Figure out correct nullability semantics of IN.
override def toString = s"$value IN ${list.mkString("(", ",", ")")}"
@@ -197,7 +197,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
def children = predicate :: trueValue :: falseValue :: Nil
override def nullable = trueValue.nullable || falseValue.nullable
- def references = children.flatMap(_.references).toSet
+
override lazy val resolved = childrenResolved && trueValue.dataType == falseValue.dataType
def dataType = {
if (!resolved) {
@@ -239,7 +239,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
case class CaseWhen(branches: Seq[Expression]) extends Expression {
type EvaluatedType = Any
def children = branches
- def references = children.flatMap(_.references).toSet
+
def dataType = {
if (!resolved) {
throw new UnresolvedException(this, "cannot resolve due to differing types in some branches")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
index e6c570b47b..3d4c4a8853 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
@@ -26,8 +26,6 @@ import org.apache.spark.util.collection.OpenHashSet
case class NewSet(elementType: DataType) extends LeafExpression {
type EvaluatedType = Any
- def references = Set.empty
-
def nullable = false
// We are currently only using these Expressions internally for aggregation. However, if we ever
@@ -53,9 +51,6 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
def nullable = set.nullable
def dataType = set.dataType
-
- def references = (item.flatMap(_.references) ++ set.flatMap(_.references)).toSet
-
def eval(input: Row): Any = {
val itemEval = item.eval(input)
val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 97fc3a3b14..c2a3a5ca3c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -226,8 +226,6 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends Expression
if (str.dataType == BinaryType) str.dataType else StringType
}
- def references = children.flatMap(_.references).toSet
-
override def children = str :: pos :: len :: Nil
@inline
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 5f86d6047c..ddd4b3755d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -65,8 +65,10 @@ object ColumnPruning extends Rule[LogicalPlan] {
// Eliminate unneeded attributes from either side of a Join.
case Project(projectList, Join(left, right, joinType, condition)) =>
// Collect the list of all references required either above or to evaluate the condition.
- val allReferences: Set[Attribute] =
- projectList.flatMap(_.references).toSet ++ condition.map(_.references).getOrElse(Set.empty)
+ val allReferences: AttributeSet =
+ AttributeSet(
+ projectList.flatMap(_.references.iterator)) ++
+ condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
/** Applies a projection only when the child is producing unnecessary attributes */
def pruneJoinChild(c: LogicalPlan) = prunedChild(c, allReferences)
@@ -76,8 +78,8 @@ object ColumnPruning extends Rule[LogicalPlan] {
// Eliminate unneeded attributes from right side of a LeftSemiJoin.
case Join(left, right, LeftSemi, condition) =>
// Collect the list of all references required to evaluate the condition.
- val allReferences: Set[Attribute] =
- condition.map(_.references).getOrElse(Set.empty)
+ val allReferences: AttributeSet =
+ condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
Join(left, prunedChild(right, allReferences), LeftSemi, condition)
@@ -104,7 +106,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
}
/** Applies a projection only when the child is producing unnecessary attributes */
- private def prunedChild(c: LogicalPlan, allReferences: Set[Attribute]) =
+ private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) {
Project(allReferences.filter(c.outputSet.contains).toSeq, c)
} else {
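
The signature change to prunedChild is mechanical, but the rule it supports is worth
spelling out: a Project is inserted only when the child produces attributes nobody
upstream needs. A toy sketch under that assumption — Plan, Scan, Proj, and the
string-valued attributes are stand-ins for the Catalyst operators:

```scala
sealed trait Plan { def output: Set[String] }
case class Scan(output: Set[String]) extends Plan
case class Proj(projectList: Set[String], child: Plan) extends Plan {
  def output: Set[String] = projectList
}

object PruneDemo extends App {
  def prunedChild(c: Plan, allReferences: Set[String]): Plan =
    if ((c.output -- allReferences).nonEmpty) {
      Proj(allReferences.intersect(c.output), c) // keep only the needed columns
    } else {
      c // nothing to prune; keep the plan unchanged
    }

  // A scan of (a, b, c) feeding a consumer that only needs a and b:
  assert(prunedChild(Scan(Set("a", "b", "c")), Set("a", "b")) ==
    Proj(Set("a", "b"), Scan(Set("a", "b", "c"))))
}
```
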
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 0988b0c6d9..1e177e28f8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.plans
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression}
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.types.{ArrayType, DataType, StructField, StructType}
@@ -29,7 +29,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanType]
/**
* Returns the set of attributes that are output by this node.
*/
- def outputSet: Set[Attribute] = output.toSet
+ def outputSet: AttributeSet = AttributeSet(output)
/**
* Runs [[transform]] with `rule` on all expressions present in this query operator.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 278569f0cb..8616ac45b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -46,16 +46,10 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
)
/**
- * Returns the set of attributes that are referenced by this node
- * during evaluation.
- */
- def references: Set[Attribute]
-
- /**
* Returns the set of attributes that this node takes as
* input from its children.
*/
- lazy val inputSet: Set[Attribute] = children.flatMap(_.output).toSet
+ lazy val inputSet: AttributeSet = AttributeSet(children.flatMap(_.output))
/**
* Returns true if this expression and all its children have been resolved to a specific schema
@@ -126,9 +120,6 @@ abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] {
override lazy val statistics: Statistics =
throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.")
-
- // Leaf nodes by definition cannot reference any input attributes.
- override def references = Set.empty
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
index d3f9d0fb93..4460c86ed9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
@@ -30,6 +30,4 @@ case class ScriptTransformation(
input: Seq[Expression],
script: String,
output: Seq[Attribute],
- child: LogicalPlan) extends UnaryNode {
- def references = input.flatMap(_.references).toSet
-}
+ child: LogicalPlan) extends UnaryNode
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 3cb407217c..4adfb18937 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.types._
case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
def output = projectList.map(_.toAttribute)
- def references = projectList.flatMap(_.references).toSet
}
/**
@@ -59,14 +58,10 @@ case class Generate(
override def output =
if (join) child.output ++ generatorOutput else generatorOutput
-
- override def references =
- if (join) child.outputSet else generator.references
}
case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
override def output = child.output
- override def references = condition.references
}
case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
@@ -76,8 +71,6 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
override lazy val resolved =
childrenResolved &&
!left.output.zip(right.output).exists { case (l,r) => l.dataType != r.dataType }
-
- override def references = Set.empty
}
case class Join(
@@ -86,8 +79,6 @@ case class Join(
joinType: JoinType,
condition: Option[Expression]) extends BinaryNode {
- override def references = condition.map(_.references).getOrElse(Set.empty)
-
override def output = {
joinType match {
case LeftSemi =>
@@ -106,8 +97,6 @@ case class Join(
case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
def output = left.output
-
- def references = Set.empty
}
case class InsertIntoTable(
@@ -118,7 +107,6 @@ case class InsertIntoTable(
extends LogicalPlan {
// The table being inserted into is a child for the purposes of transformations.
override def children = table :: child :: Nil
- override def references = Set.empty
override def output = child.output
override lazy val resolved = childrenResolved && child.output.zip(table.output).forall {
@@ -130,20 +118,17 @@ case class InsertIntoCreatedTable(
databaseName: Option[String],
tableName: String,
child: LogicalPlan) extends UnaryNode {
- override def references = Set.empty
override def output = child.output
}
case class WriteToFile(
path: String,
child: LogicalPlan) extends UnaryNode {
- override def references = Set.empty
override def output = child.output
}
case class Sort(order: Seq[SortOrder], child: LogicalPlan) extends UnaryNode {
override def output = child.output
- override def references = order.flatMap(_.references).toSet
}
case class Aggregate(
@@ -152,19 +137,20 @@ case class Aggregate(
child: LogicalPlan)
extends UnaryNode {
+ /** The set of all AttributeReferences required for this aggregation. */
+ def references =
+ AttributeSet(
+ groupingExpressions.flatMap(_.references) ++ aggregateExpressions.flatMap(_.references))
+
override def output = aggregateExpressions.map(_.toAttribute)
- override def references =
- (groupingExpressions ++ aggregateExpressions).flatMap(_.references).toSet
}
case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
override def output = child.output
- override def references = limitExpr.references
}
case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
override def output = child.output.map(_.withQualifiers(alias :: Nil))
- override def references = Set.empty
}
/**
@@ -191,20 +177,16 @@ case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode {
a.qualifiers)
case other => other
}
-
- override def references = Set.empty
}
case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan)
extends UnaryNode {
override def output = child.output
- override def references = Set.empty
}
case class Distinct(child: LogicalPlan) extends UnaryNode {
override def output = child.output
- override def references = child.outputSet
}
case object NoRelation extends LeafNode {
@@ -213,5 +195,4 @@ case object NoRelation extends LeafNode {
case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
override def output = left.output
- override def references = Set.empty
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
index 7146fbd540..72b0c5c8e7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
@@ -31,13 +31,9 @@ abstract class RedistributeData extends UnaryNode {
case class SortPartitions(sortExpressions: Seq[SortOrder], child: LogicalPlan)
extends RedistributeData {
-
- def references = sortExpressions.flatMap(_.references).toSet
}
case class Repartition(partitionExpressions: Seq[Expression], child: LogicalPlan)
extends RedistributeData {
-
- def references = partitionExpressions.flatMap(_.references).toSet
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 4bb022cf23..ccb0df113c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -71,6 +71,7 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
"An AllTuples should be used to represent a distribution that only has " +
"a single partition.")
+ // TODO: This is not really valid...
def clustering = ordering.map(_.child).toSet
}
@@ -139,7 +140,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
with Partitioning {
override def children = expressions
- override def references = expressions.flatMap(_.references).toSet
override def nullable = false
override def dataType = IntegerType
@@ -179,7 +179,6 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
with Partitioning {
override def children = ordering
- override def references = ordering.flatMap(_.references).toSet
override def nullable = false
override def dataType = IntegerType
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 6344874538..296202543e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.types.{StringType, NullType}
case class Dummy(optKey: Option[Expression]) extends Expression {
def children = optKey.toSeq
- def references = Set.empty[Attribute]
def nullable = true
def dataType = NullType
override lazy val resolved = true
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 8a9f4deb6a..6f0eed3f63 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -344,8 +344,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
prunePushedDownFilters: Seq[Expression] => Seq[Expression],
scanBuilder: Seq[Attribute] => SparkPlan): SparkPlan = {
- val projectSet = projectList.flatMap(_.references).toSet
- val filterSet = filterPredicates.flatMap(_.references).toSet
+ val projectSet = AttributeSet(projectList.flatMap(_.references))
+ val filterSet = AttributeSet(filterPredicates.flatMap(_.references))
val filterCondition = prunePushedDownFilters(filterPredicates).reduceLeftOption(And)
// Right now we still use a projection even if the only evaluation is applying an alias
@@ -354,7 +354,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
// TODO: Decouple final output schema from expression evaluation so this copy can be
// avoided safely.
- if (projectList.toSet == projectSet && filterSet.subsetOf(projectSet)) {
+ if (AttributeSet(projectList.map(_.toAttribute)) == projectSet &&
+ filterSet.subsetOf(projectSet)) {
// When it is possible to just use column pruning to get the right projection and
// when the columns of this projection are enough to evaluate all filter conditions,
// just do a scan followed by a filter, with no extra project.
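
The rewritten condition above encodes a fast path: skip the extra Project when the
projection is pure column selection (the attributes it outputs are exactly the ones
it reads) and every filter can be answered from those columns. A hedged sketch with
illustrative stand-in types, not the real SQLContext API:

```scala
case class Col(name: String)                  // stand-in for an Attribute
case class Item(output: Col, reads: Set[Col]) // stand-in for a NamedExpression

object ScanOnlyDemo extends App {
  // A bare scan suffices when the projection merely selects columns and the
  // filters read only those columns; otherwise a Project must be kept.
  def scanIsEnough(projectList: Seq[Item], filterReads: Set[Col]): Boolean = {
    val produced = projectList.map(_.output).toSet
    val consumed = projectList.flatMap(_.reads).toSet
    produced == consumed && filterReads.subsetOf(produced)
  }

  val a = Col("a"); val b = Col("b")
  val selectAB = Seq(Item(a, Set(a)), Item(b, Set(b))) // SELECT a, b
  val aliased  = Seq(Item(Col("s"), Set(a, b)))        // SELECT a + b AS s
  assert(scanIsEnough(selectAB, filterReads = Set(a))) // scan + filter only
  assert(!scanIsEnough(aliased, filterReads = Set(a))) // still needs a Project
}
```
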
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index e63b490304..24e88eea31 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -79,8 +79,6 @@ private[sql] case class InMemoryRelation(
override def children = Seq.empty
- override def references = Set.empty
-
override def newInstance() = {
new InMemoryRelation(
output.map(_.newInstance),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 21cbbc9772..7d33ea5b02 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -141,10 +141,9 @@ case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQLContext)
extends LogicalPlan with MultiInstanceRelation {
def output = alreadyPlanned.output
- override def references = Set.empty
override def children = Nil
- override final def newInstance: this.type = {
+ override final def newInstance(): this.type = {
SparkLogicalPlan(
alreadyPlanned match {
case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index f31df05182..5b896c55b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -58,8 +58,6 @@ package object debug {
}
private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode {
- def references = Set.empty
-
def output = child.output
implicit object SetAccumulatorParam extends AccumulatorParam[HashSet[String]] {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index b92091b560..aef6ebf86b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -49,7 +49,6 @@ private[spark] case class PythonUDF(
override def toString = s"PythonUDF#$name(${children.mkString(",")})"
def nullable: Boolean = true
- def references: Set[Attribute] = children.flatMap(_.references).toSet
override def eval(input: Row) = sys.error("PythonUDFs can not be directly evaluated.")
}
@@ -113,7 +112,6 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] {
case class EvaluatePython(udf: PythonUDF, child: LogicalPlan) extends logical.UnaryNode {
val resultAttribute = AttributeReference("pythonUDF", udf.dataType, nullable=true)()
- def references = Set.empty
def output = child.output :+ resultAttribute
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 389ace726d..10fa8314c9 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -79,9 +79,9 @@ private[hive] trait HiveStrategies {
hiveContext.convertMetastoreParquet =>
// Filter out all predicates that only deal with partition keys
- val partitionKeyIds = relation.partitionKeys.map(_.exprId).toSet
+ val partitionKeys = AttributeSet(relation.partitionKeys)
val (pruningPredicates, otherPredicates) = predicates.partition {
- _.references.map(_.exprId).subsetOf(partitionKeyIds)
+ _.references.subsetOf(partitionKeys)
}
// We are going to throw the predicates and projection back at the whole optimization
@@ -176,9 +176,9 @@ private[hive] trait HiveStrategies {
case PhysicalOperation(projectList, predicates, relation: MetastoreRelation) =>
// Filter out all predicates that only deal with partition keys, these are given to the
// hive table scan operator to be used for partition pruning.
- val partitionKeyIds = relation.partitionKeys.map(_.exprId).toSet
+ val partitionKeys = AttributeSet(relation.partitionKeys)
val (pruningPredicates, otherPredicates) = predicates.partition {
- _.references.map(_.exprId).subsetOf(partitionKeyIds)
+ _.references.subsetOf(partitionKeys)
}
pruneFilterProject(
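
Both HiveStrategies hunks apply the same split: a predicate may be used to prune
partitions only when every attribute it references is a partition key, and the new
`subsetOf` on AttributeSet expresses that directly. A small sketch of the split —
Pred and the string attributes are illustrative stand-ins for Catalyst expressions:

```scala
object PartitionPruningDemo extends App {
  case class Pred(references: Set[String])

  val partitionKeys = Set("ds", "country")
  val predicates = Seq(
    Pred(Set("ds")),            // partition keys only -> prunes partitions
    Pred(Set("ds", "clicks"))   // touches a data column -> row-level filter
  )
  val (pruningPredicates, otherPredicates) =
    predicates.partition(_.references.subsetOf(partitionKeys))
  assert(pruningPredicates == Seq(Pred(Set("ds"))))
  assert(otherPredicates == Seq(Pred(Set("ds", "clicks"))))
}
```
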
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index c6497a15ef..7d1ad53d8b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -88,7 +88,6 @@ private[hive] abstract class HiveUdf extends Expression with Logging with HiveFunctionFactory
type EvaluatedType = Any
def nullable = true
- def references = children.flatMap(_.references).toSet
lazy val function = createFunction[UDFType]()
@@ -229,8 +228,6 @@ private[hive] case class HiveGenericUdaf(
def nullable: Boolean = true
- def references: Set[Attribute] = children.map(_.references).flatten.toSet
-
override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
def newInstance() = new HiveUdafFunction(functionClassName, children, this)
@@ -253,8 +250,6 @@ private[hive] case class HiveGenericUdtf(
children: Seq[Expression])
extends Generator with HiveInspectors with HiveFunctionFactory {
- override def references = children.flatMap(_.references).toSet
-
@transient
protected lazy val function: GenericUDTF = createFunction()
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
index 6b3ffd1c0f..b6be6bc1bf 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.hive.execution
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
-case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested])
case class Nested(a: Int, B: Int)
+case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested])
/**
* A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution.
@@ -57,6 +57,13 @@ class HiveResolutionSuite extends HiveComparisonTest {
.registerTempTable("caseSensitivityTest")
sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest")
+
+ println(sql("SELECT * FROM casesensitivitytest one JOIN casesensitivitytest two ON one.a = two.a").queryExecution)
+
+ sql("SELECT * FROM casesensitivitytest one JOIN casesensitivitytest two ON one.a = two.a").collect()
+
+ // TODO: sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a")
+
}
test("nested repeated resolution") {