From 276ef1c3cfd44b5fc082e1a495fff22fbaf6add3 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 25 Mar 2015 19:21:54 -0700 Subject: [SPARK-6463][SQL] AttributeSet.equal should compare size Previously this could result in sets compare equals when in fact the right was a subset of the left. Based on #5133 by sisihj Author: sisihj Author: Michael Armbrust Closes #5194 from marmbrus/pr/5133 and squashes the following commits: 5ed4615 [Michael Armbrust] fix imports d4cbbc0 [Michael Armbrust] Add test cases 0a0834f [sisihj] AttributeSet.equal should compare size --- .../sql/catalyst/expressions/AttributeSet.scala | 3 +- .../catalyst/expressions/AttributeSetSuite.scala | 82 ++++++++++++++++++++++ 2 files changed, 84 insertions(+), 1 deletion(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index f9ae85a5cf..11b4eb5c88 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -58,7 +58,8 @@ class AttributeSet private (val baseSet: Set[AttributeEquals]) /** Returns true if the members of this AttributeSet and other are the same. */ override def equals(other: Any): Boolean = other match { - case otherSet: AttributeSet => baseSet.map(_.a).forall(otherSet.contains) + case otherSet: AttributeSet => + otherSet.size == baseSet.size && baseSet.map(_.a).forall(otherSet.contains) case _ => false } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala new file mode 100644 index 0000000000..f2f3a84d19 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.scalatest.FunSuite + +import org.apache.spark.sql.types.IntegerType + +class AttributeSetSuite extends FunSuite { + + val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1)) + val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1)) + val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3)) + val aSet = AttributeSet(aLower :: Nil) + + val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2)) + val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2)) + val bSet = AttributeSet(bUpper :: Nil) + + val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil) + + test("sanity check") { + assert(aUpper != aLower) + assert(bUpper != bLower) + } + + test("checks by id not name") { + assert(aSet.contains(aUpper) === true) + assert(aSet.contains(aLower) === true) + assert(aSet.contains(fakeA) === false) + + assert(aSet.contains(bUpper) === false) + assert(aSet.contains(bLower) === false) + } + + test("++ preserves AttributeSet") { + assert((aSet ++ bSet).contains(aUpper) === true) + assert((aSet ++ bSet).contains(aLower) === true) + } + + test("extracts all references references") { + val addSet = AttributeSet(Add(aUpper, Alias(bUpper, "test")()):: Nil) + assert(addSet.contains(aUpper)) + assert(addSet.contains(aLower)) + assert(addSet.contains(bUpper)) + assert(addSet.contains(bLower)) + } + + test("dedups attributes") { + assert(AttributeSet(aUpper :: aLower :: Nil).size === 1) + } + + test("subset") { + assert(aSet.subsetOf(aAndBSet) === true) + assert(aAndBSet.subsetOf(aSet) === false) + } + + test("equality") { + assert(aSet != aAndBSet) + assert(aAndBSet != aSet) + assert(aSet != bSet) + assert(bSet != aSet) + + assert(aSet == aSet) + assert(aSet == AttributeSet(aUpper :: Nil)) + } +} -- cgit v1.2.3