diff options
Diffstat (limited to 'sql/catalyst/src/test')
2 files changed, 48 insertions, 3 deletions
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index d7eb13c50b..7beef71845 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -21,7 +21,8 @@ import scala.collection.immutable.HashSet import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType, BooleanType} +import org.apache.spark.sql.RandomDataGenerator +import org.apache.spark.sql.types._ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -118,6 +119,23 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"))), true) checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true) checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false) + + val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType, + LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType) + primitiveTypes.map { t => + val dataGen = RandomDataGenerator.forType(t, nullable = false).get + val inputData = Seq.fill(10) { + val value = dataGen.apply() + value match { + case d: Double if d.isNaN => 0.0d + case f: Float if f.isNaN => 0.0f + case _ => value + } + } + val input = inputData.map(Literal(_)) + checkEvaluation(In(input(0), input.slice(1, 10)), + inputData.slice(1, 10).contains(inputData(0))) + } } test("INSET") { @@ -134,6 +152,23 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(InSet(three, hS), false) checkEvaluation(InSet(three, nS), false) checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true) + + val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType, + LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType) + primitiveTypes.map { t => + val dataGen = RandomDataGenerator.forType(t, nullable = false).get + val inputData = Seq.fill(10) { + val value = dataGen.apply() + value match { + case d: Double if d.isNaN => 0.0d + case f: Float if f.isNaN => 0.0f + case _ => value + } + } + val input = inputData.map(Literal(_)) + checkEvaluation(InSet(input(0), inputData.slice(1, 10).toSet), + inputData.slice(1, 10).contains(inputData(0))) + } } private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d, false).map(Literal(_)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 1d433275fe..6f7b5b9572 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -43,16 +43,26 @@ class OptimizeInSuite extends PlanTest { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - test("OptimizedIn test: In clause optimized to InSet") { + test("OptimizedIn test: In clause not optimized to InSet when less than 10 items") { val originalQuery = testRelation .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2)))) .analyze val optimized = Optimize.execute(originalQuery.analyze) + comparePlans(optimized, originalQuery) + } + + test("OptimizedIn test: In clause optimized to InSet when more than 10 items") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), (1 to 11).map(Literal(_)))) + .analyze + + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(InSet(UnresolvedAttribute("a"), HashSet[Any]() + 1 + 2)) + .where(InSet(UnresolvedAttribute("a"), (1 to 11).toSet)) .analyze comparePlans(optimized, correctAnswer) |