diff options
author | Michael Armbrust <michael@databricks.com> | 2015-03-25 19:21:54 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-03-25 19:22:05 -0700 |
commit | 276ef1c3cfd44b5fc082e1a495fff22fbaf6add3 (patch) | |
tree | bdf891c9be37b04f4dae84ef3602acaf9c779264 /sql | |
parent | e87bf3713e684fa83165a1036d76f7a84f043775 (diff) | |
download | spark-276ef1c3cfd44b5fc082e1a495fff22fbaf6add3.tar.gz spark-276ef1c3cfd44b5fc082e1a495fff22fbaf6add3.tar.bz2 spark-276ef1c3cfd44b5fc082e1a495fff22fbaf6add3.zip |
[SPARK-6463][SQL] AttributeSet.equal should compare size
Previously this could result in sets compare equals when in fact the right was a subset of the left.
Based on #5133 by sisihj
Author: sisihj <jun.hejun@huawei.com>
Author: Michael Armbrust <michael@databricks.com>
Closes #5194 from marmbrus/pr/5133 and squashes the following commits:
5ed4615 [Michael Armbrust] fix imports
d4cbbc0 [Michael Armbrust] Add test cases
0a0834f [sisihj] AttributeSet.equal should compare size
Diffstat (limited to 'sql')
2 files changed, 84 insertions, 1 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index f9ae85a5cf..11b4eb5c88 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -58,7 +58,8 @@ class AttributeSet private (val baseSet: Set[AttributeEquals]) /** Returns true if the members of this AttributeSet and other are the same. */ override def equals(other: Any): Boolean = other match { - case otherSet: AttributeSet => baseSet.map(_.a).forall(otherSet.contains) + case otherSet: AttributeSet => + otherSet.size == baseSet.size && baseSet.map(_.a).forall(otherSet.contains) case _ => false } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala new file mode 100644 index 0000000000..f2f3a84d19 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.scalatest.FunSuite + +import org.apache.spark.sql.types.IntegerType + +class AttributeSetSuite extends FunSuite { + + val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1)) + val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1)) + val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3)) + val aSet = AttributeSet(aLower :: Nil) + + val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2)) + val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2)) + val bSet = AttributeSet(bUpper :: Nil) + + val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil) + + test("sanity check") { + assert(aUpper != aLower) + assert(bUpper != bLower) + } + + test("checks by id not name") { + assert(aSet.contains(aUpper) === true) + assert(aSet.contains(aLower) === true) + assert(aSet.contains(fakeA) === false) + + assert(aSet.contains(bUpper) === false) + assert(aSet.contains(bLower) === false) + } + + test("++ preserves AttributeSet") { + assert((aSet ++ bSet).contains(aUpper) === true) + assert((aSet ++ bSet).contains(aLower) === true) + } + + test("extracts all references references") { + val addSet = AttributeSet(Add(aUpper, Alias(bUpper, "test")()):: Nil) + assert(addSet.contains(aUpper)) + assert(addSet.contains(aLower)) + assert(addSet.contains(bUpper)) + assert(addSet.contains(bLower)) + } + + test("dedups attributes") { + assert(AttributeSet(aUpper :: aLower :: Nil).size === 1) + } + + test("subset") { + assert(aSet.subsetOf(aAndBSet) === true) + assert(aAndBSet.subsetOf(aSet) === false) + } + + test("equality") { + assert(aSet != aAndBSet) + assert(aAndBSet != aSet) + assert(aSet != bSet) + assert(bSet != aSet) + + assert(aSet == aSet) + assert(aSet == AttributeSet(aUpper :: Nil)) + } +} |