diff options
author | gatorsmile <gatorsmile@gmail.com> | 2015-11-16 15:22:12 -0800 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-11-16 15:22:12 -0800 |
commit | 75ee12f09c2645c1ad682764d512965f641eb5c2 (patch) | |
tree | 8fd6c52cf5cd800e1faac70097c3cec1a6e1fb0b /sql/catalyst | |
parent | 31296628ac7cd7be71e0edca335dc8604f62bb47 (diff) | |
download | spark-75ee12f09c2645c1ad682764d512965f641eb5c2.tar.gz spark-75ee12f09c2645c1ad682764d512965f641eb5c2.tar.bz2 spark-75ee12f09c2645c1ad682764d512965f641eb5c2.zip |
[SPARK-8658][SQL] AttributeReference's equals method compares all the members
This fix is to change the equals method to check all of the specified fields for equality of AttributeReference.
Author: gatorsmile <gatorsmile@gmail.com>
Closes #9216 from gatorsmile/namedExpressEqual.
Diffstat (limited to 'sql/catalyst')
3 files changed, 14 insertions, 12 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index f80bcfcb0b..e3daddace2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -194,7 +194,9 @@ case class AttributeReference( def sameRef(other: AttributeReference): Boolean = this.exprId == other.exprId override def equals(other: Any): Boolean = other match { - case ar: AttributeReference => name == ar.name && exprId == ar.exprId && dataType == ar.dataType + case ar: AttributeReference => + name == ar.name && dataType == ar.dataType && nullable == ar.nullable && + metadata == ar.metadata && exprId == ar.exprId && qualifiers == ar.qualifiers case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index e2b97b27a6..45630a591d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ -import org.apache.spark.util.collection.OpenHashSet +import scala.collection.mutable.ArrayBuffer case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) @@ -244,12 +244,12 @@ private[sql] object Expand { */ private def buildNonSelectExprSet( bitmask: Int, - exprs: Seq[Expression]): OpenHashSet[Expression] = { - val set = new OpenHashSet[Expression](2) + exprs: Seq[Expression]): ArrayBuffer[Expression] = { + val set = new ArrayBuffer[Expression](2) var bit = exprs.length - 1 while (bit >= 0) { - if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit)) + if (((bitmask >> bit) & 1) == 0) set += exprs(bit) bit -= 1 } @@ -279,7 +279,7 @@ private[sql] object Expand { (child.output :+ gid).map(expr => expr transformDown { // TODO this causes a problem when a column is used both for grouping and aggregation. - case x: Expression if nonSelectedGroupExprSet.contains(x) => + case x: Expression if nonSelectedGroupExprSet.exists(_.semanticEquals(x)) => // if the input attribute in the Invalid Grouping Expression set of for this group // replace it with constant null Literal.create(null, expr.dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 86b9417477..f6fb31a2af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -235,17 +235,17 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true case ClusteredDistribution(requiredClustering) => - expressions.toSet.subsetOf(requiredClustering.toSet) + expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) case _ => false } override def compatibleWith(other: Partitioning): Boolean = other match { - case o: HashPartitioning => this == o + case o: HashPartitioning => this.semanticEquals(o) case _ => false } override def guarantees(other: Partitioning): Boolean = other match { - case o: HashPartitioning => this == o + case o: HashPartitioning => this.semanticEquals(o) case _ => false } @@ -276,17 +276,17 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) val minSize = Seq(requiredOrdering.size, ordering.size).min requiredOrdering.take(minSize) == ordering.take(minSize) case ClusteredDistribution(requiredClustering) => - ordering.map(_.child).toSet.subsetOf(requiredClustering.toSet) + ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) case _ => false } override def compatibleWith(other: Partitioning): Boolean = other match { - case o: RangePartitioning => this == o + case o: RangePartitioning => this.semanticEquals(o) case _ => false } override def guarantees(other: Partitioning): Boolean = other match { - case o: RangePartitioning => this == o + case o: RangePartitioning => this.semanticEquals(o) case _ => false } } |