aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst/src/test')
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala37
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala14
2 files changed, 48 insertions, 3 deletions
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index d7eb13c50b..7beef71845 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -21,7 +21,8 @@ import scala.collection.immutable.HashSet
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType, BooleanType}
+import org.apache.spark.sql.RandomDataGenerator
+import org.apache.spark.sql.types._
class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -118,6 +119,23 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"))), true)
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true)
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false)
+
+ val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType,
+ LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
+ primitiveTypes.map { t =>
+ val dataGen = RandomDataGenerator.forType(t, nullable = false).get
+ val inputData = Seq.fill(10) {
+ val value = dataGen.apply()
+ value match {
+ case d: Double if d.isNaN => 0.0d
+ case f: Float if f.isNaN => 0.0f
+ case _ => value
+ }
+ }
+ val input = inputData.map(Literal(_))
+ checkEvaluation(In(input(0), input.slice(1, 10)),
+ inputData.slice(1, 10).contains(inputData(0)))
+ }
}
test("INSET") {
@@ -134,6 +152,23 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(InSet(three, hS), false)
checkEvaluation(InSet(three, nS), false)
checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true)
+
+ val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType,
+ LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
+ primitiveTypes.map { t =>
+ val dataGen = RandomDataGenerator.forType(t, nullable = false).get
+ val inputData = Seq.fill(10) {
+ val value = dataGen.apply()
+ value match {
+ case d: Double if d.isNaN => 0.0d
+ case f: Float if f.isNaN => 0.0f
+ case _ => value
+ }
+ }
+ val input = inputData.map(Literal(_))
+ checkEvaluation(InSet(input(0), inputData.slice(1, 10).toSet),
+ inputData.slice(1, 10).contains(inputData(0)))
+ }
}
private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d, false).map(Literal(_))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
index 1d433275fe..6f7b5b9572 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
@@ -43,16 +43,26 @@ class OptimizeInSuite extends PlanTest {
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
- test("OptimizedIn test: In clause optimized to InSet") {
+ test("OptimizedIn test: In clause not optimized to InSet when less than 10 items") {
val originalQuery =
testRelation
.where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2))))
.analyze
val optimized = Optimize.execute(originalQuery.analyze)
+ comparePlans(optimized, originalQuery)
+ }
+
+ test("OptimizedIn test: In clause optimized to InSet when more than 10 items") {
+ val originalQuery =
+ testRelation
+ .where(In(UnresolvedAttribute("a"), (1 to 11).map(Literal(_))))
+ .analyze
+
+ val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
- .where(InSet(UnresolvedAttribute("a"), HashSet[Any]() + 1 + 2))
+ .where(InSet(UnresolvedAttribute("a"), (1 to 11).toSet))
.analyze
comparePlans(optimized, correctAnswer)