author     gatorsmile <gatorsmile@gmail.com>          2015-11-16 15:22:12 -0800
committer  Michael Armbrust <michael@databricks.com>  2015-11-16 15:22:12 -0800
commit     75ee12f09c2645c1ad682764d512965f641eb5c2 (patch)
tree       8fd6c52cf5cd800e1faac70097c3cec1a6e1fb0b /sql/catalyst
parent     31296628ac7cd7be71e0edca335dc8604f62bb47 (diff)
[SPARK-8658][SQL] AttributeReference's equals method compares all the members
This fix changes AttributeReference's equals method to check all of the specified fields for equality, rather than only name, exprId, and dataType.

Author: gatorsmile <gatorsmile@gmail.com>

Closes #9216 from gatorsmile/namedExpressEqual.
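A minimal sketch of the behavior being fixed, assuming Spark 1.6-era catalyst APIs (the fixed ExprId(1) and the "comment" metadata key are hypothetical, chosen only for illustration):

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExprId}
    import org.apache.spark.sql.types.{IntegerType, MetadataBuilder}

    val m = new MetadataBuilder().putString("comment", "user id").build()
    // Two references that agree on name, exprId, and dataType but not on metadata.
    val plain     = AttributeReference("id", IntegerType)(exprId = ExprId(1))
    val annotated = AttributeReference("id", IntegerType, nullable = true, metadata = m)(exprId = ExprId(1))

    // Before this patch the two compared equal and the metadata difference was
    // silently ignored; with the stricter equals in the diff below, plain != annotated.
    plain == annotated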
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala  |  4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala  | 10
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala   | 12
3 files changed, 14 insertions(+), 12 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index f80bcfcb0b..e3daddace2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -194,7 +194,9 @@ case class AttributeReference(
def sameRef(other: AttributeReference): Boolean = this.exprId == other.exprId
override def equals(other: Any): Boolean = other match {
- case ar: AttributeReference => name == ar.name && exprId == ar.exprId && dataType == ar.dataType
+ case ar: AttributeReference =>
+ name == ar.name && dataType == ar.dataType && nullable == ar.nullable &&
+ metadata == ar.metadata && exprId == ar.exprId && qualifiers == ar.qualifiers
case _ => false
}
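The practical effect of the stricter equals, as a hedged sketch (withNullability copies the attribute while keeping its exprId, so sameRef is unaffected):

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.types.IntegerType

    val a = AttributeReference("a", IntegerType, nullable = true)()
    val b = a.withNullability(false)  // same exprId, different nullability

    a == b        // false after this patch: nullable now participates in equals
    a.sameRef(b)  // still true: the exprId comparison is unchanged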
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index e2b97b27a6..45630a591d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.types._
-import org.apache.spark.util.collection.OpenHashSet
+import scala.collection.mutable.ArrayBuffer
case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
@@ -244,12 +244,12 @@ private[sql] object Expand {
*/
private def buildNonSelectExprSet(
bitmask: Int,
- exprs: Seq[Expression]): OpenHashSet[Expression] = {
- val set = new OpenHashSet[Expression](2)
+ exprs: Seq[Expression]): ArrayBuffer[Expression] = {
+ val set = new ArrayBuffer[Expression](2)
var bit = exprs.length - 1
while (bit >= 0) {
- if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit))
+ if (((bitmask >> bit) & 1) == 0) set += exprs(bit)
bit -= 1
}
@@ -279,7 +279,7 @@ private[sql] object Expand {
(child.output :+ gid).map(expr => expr transformDown {
// TODO this causes a problem when a column is used both for grouping and aggregation.
- case x: Expression if nonSelectedGroupExprSet.contains(x) =>
+ case x: Expression if nonSelectedGroupExprSet.exists(_.semanticEquals(x)) =>
// if the input attribute is in the invalid grouping expression set for this group
// replace it with constant null
Literal.create(null, expr.dataType)
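Why the collection changed along with the membership test: OpenHashSet.contains relies on equals/hashCode, so with the stricter equals above it would stop recognizing a grouping expression that differs only cosmetically, e.g. in its qualifiers. A minimal sketch of the difference, with a hypothetical table qualifier "t":

    import scala.collection.mutable.ArrayBuffer
    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
    import org.apache.spark.sql.types.StringType

    val g  = AttributeReference("g", StringType)()
    val gq = g.withQualifiers("t" :: Nil)  // same exprId, extra qualifier

    val nonSelected = ArrayBuffer[Expression](g)
    nonSelected.contains(gq)                  // false: == now compares qualifiers too
    nonSelected.exists(_.semanticEquals(gq))  // true: matched via the shared exprId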
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 86b9417477..f6fb31a2af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -235,17 +235,17 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
override def satisfies(required: Distribution): Boolean = required match {
case UnspecifiedDistribution => true
case ClusteredDistribution(requiredClustering) =>
- expressions.toSet.subsetOf(requiredClustering.toSet)
+ expressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
case _ => false
}
override def compatibleWith(other: Partitioning): Boolean = other match {
- case o: HashPartitioning => this == o
+ case o: HashPartitioning => this.semanticEquals(o)
case _ => false
}
override def guarantees(other: Partitioning): Boolean = other match {
- case o: HashPartitioning => this == o
+ case o: HashPartitioning => this.semanticEquals(o)
case _ => false
}
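The same reasoning drives the partitioning changes: the old toSet.subsetOf test compared expressions with ==, so a qualifier-only mismatch between the partitioning expressions and the required clustering could make satisfies return false and force a redundant exchange. A minimal sketch, assuming hypothetical attributes k and its qualified copy kq:

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning}
    import org.apache.spark.sql.types.LongType

    val k  = AttributeReference("k", LongType)()
    val kq = k.withQualifiers("t" :: Nil)  // same exprId, different qualifiers

    // true with the semantic check; the == based subset test would say false
    HashPartitioning(Seq(k), 8).satisfies(ClusteredDistribution(Seq(kq)))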
@@ -276,17 +276,17 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
override def satisfies(required: Distribution): Boolean = required match {
case UnspecifiedDistribution => true
case OrderedDistribution(requiredOrdering) =>
val minSize = Seq(requiredOrdering.size, ordering.size).min
requiredOrdering.take(minSize) == ordering.take(minSize)
case ClusteredDistribution(requiredClustering) =>
- ordering.map(_.child).toSet.subsetOf(requiredClustering.toSet)
+ ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x)))
case _ => false
}
override def compatibleWith(other: Partitioning): Boolean = other match {
- case o: RangePartitioning => this == o
+ case o: RangePartitioning => this.semanticEquals(o)
case _ => false
}
override def guarantees(other: Partitioning): Boolean = other match {
- case o: RangePartitioning => this == o
+ case o: RangePartitioning => this.semanticEquals(o)
case _ => false
}
}
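RangePartitioning receives the identical treatment. Reusing the hypothetical k and kq from the sketch above, the compatibleWith and guarantees change works the same way for both partitionings:

    val p1 = HashPartitioning(Seq(k), 8)
    val p2 = HashPartitioning(Seq(kq), 8)

    p1 == p2           // false: strict equality sees the differing qualifiers
    p1.guarantees(p2)  // true: semanticEquals matches on the underlying exprIds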