aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorMichael Armbrust <michael@databricks.com>2015-03-25 19:21:54 -0700
committerMichael Armbrust <michael@databricks.com>2015-03-25 19:22:05 -0700
commit276ef1c3cfd44b5fc082e1a495fff22fbaf6add3 (patch)
treebdf891c9be37b04f4dae84ef3602acaf9c779264 /sql
parente87bf3713e684fa83165a1036d76f7a84f043775 (diff)
downloadspark-276ef1c3cfd44b5fc082e1a495fff22fbaf6add3.tar.gz
spark-276ef1c3cfd44b5fc082e1a495fff22fbaf6add3.tar.bz2
spark-276ef1c3cfd44b5fc082e1a495fff22fbaf6add3.zip
[SPARK-6463][SQL] AttributeSet.equal should compare size
Previously this could result in sets compare equals when in fact the right was a subset of the left. Based on #5133 by sisihj Author: sisihj <jun.hejun@huawei.com> Author: Michael Armbrust <michael@databricks.com> Closes #5194 from marmbrus/pr/5133 and squashes the following commits: 5ed4615 [Michael Armbrust] fix imports d4cbbc0 [Michael Armbrust] Add test cases 0a0834f [sisihj] AttributeSet.equal should compare size
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala3
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala82
2 files changed, 84 insertions, 1 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
index f9ae85a5cf..11b4eb5c88 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
@@ -58,7 +58,8 @@ class AttributeSet private (val baseSet: Set[AttributeEquals])
/** Returns true if the members of this AttributeSet and other are the same. */
override def equals(other: Any): Boolean = other match {
- case otherSet: AttributeSet => baseSet.map(_.a).forall(otherSet.contains)
+ case otherSet: AttributeSet =>
+ otherSet.size == baseSet.size && baseSet.map(_.a).forall(otherSet.contains)
case _ => false
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala
new file mode 100644
index 0000000000..f2f3a84d19
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.types.IntegerType
+
+class AttributeSetSuite extends FunSuite {
+
+ val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1))
+ val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1))
+ val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3))
+ val aSet = AttributeSet(aLower :: Nil)
+
+ val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2))
+ val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2))
+ val bSet = AttributeSet(bUpper :: Nil)
+
+ val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil)
+
+ test("sanity check") {
+ assert(aUpper != aLower)
+ assert(bUpper != bLower)
+ }
+
+ test("checks by id not name") {
+ assert(aSet.contains(aUpper) === true)
+ assert(aSet.contains(aLower) === true)
+ assert(aSet.contains(fakeA) === false)
+
+ assert(aSet.contains(bUpper) === false)
+ assert(aSet.contains(bLower) === false)
+ }
+
+ test("++ preserves AttributeSet") {
+ assert((aSet ++ bSet).contains(aUpper) === true)
+ assert((aSet ++ bSet).contains(aLower) === true)
+ }
+
+ test("extracts all references references") {
+ val addSet = AttributeSet(Add(aUpper, Alias(bUpper, "test")()):: Nil)
+ assert(addSet.contains(aUpper))
+ assert(addSet.contains(aLower))
+ assert(addSet.contains(bUpper))
+ assert(addSet.contains(bLower))
+ }
+
+ test("dedups attributes") {
+ assert(AttributeSet(aUpper :: aLower :: Nil).size === 1)
+ }
+
+ test("subset") {
+ assert(aSet.subsetOf(aAndBSet) === true)
+ assert(aAndBSet.subsetOf(aSet) === false)
+ }
+
+ test("equality") {
+ assert(aSet != aAndBSet)
+ assert(aAndBSet != aSet)
+ assert(aSet != bSet)
+ assert(bSet != aSet)
+
+ assert(aSet == aSet)
+ assert(aSet == AttributeSet(aUpper :: Nil))
+ }
+}