aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala11
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala28
2 files changed, 25 insertions, 14 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 3cdca4e9dd..acfbbace60 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -156,12 +156,11 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE
/** A base trait for functions that compare two strings, returning a boolean. */
trait StringComparison {
- self: BinaryExpression =>
+ self: BinaryPredicate =>
- type EvaluatedType = Any
+ override type EvaluatedType = Any
override def nullable: Boolean = left.nullable || right.nullable
- override def dataType: DataType = BooleanType
def compare(l: String, r: String): Boolean
@@ -184,7 +183,7 @@ trait StringComparison {
* A function that returns true if the string `left` contains the string `right`.
*/
case class Contains(left: Expression, right: Expression)
- extends BinaryExpression with StringComparison {
+ extends BinaryPredicate with StringComparison {
override def compare(l: String, r: String): Boolean = l.contains(r)
}
@@ -192,7 +191,7 @@ case class Contains(left: Expression, right: Expression)
* A function that returns true if the string `left` starts with the string `right`.
*/
case class StartsWith(left: Expression, right: Expression)
- extends BinaryExpression with StringComparison {
+ extends BinaryPredicate with StringComparison {
override def compare(l: String, r: String): Boolean = l.startsWith(r)
}
@@ -200,7 +199,7 @@ case class StartsWith(left: Expression, right: Expression)
* A function that returns true if the string `left` ends with the string `right`.
*/
case class EndsWith(left: Expression, right: Expression)
- extends BinaryExpression with StringComparison {
+ extends BinaryPredicate with StringComparison {
override def compare(l: String, r: String): Boolean = l.endsWith(r)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
index 72ddc0ea2c..773bd1602d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
@@ -45,7 +45,9 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL
val rowBuilders = requiredColumns.map {
case "a" => (i: Int) => Seq(i)
case "b" => (i: Int) => Seq(i * 2)
- case "c" => (i: Int) => Seq((i - 1 + 'a').toChar.toString * 10)
+ case "c" => (i: Int) =>
+ val c = (i - 1 + 'a').toChar.toString
+ Seq(c * 5 + c.toUpperCase() * 5)
}
FiltersPushed.list = filters
@@ -77,7 +79,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL
}
def eval(a: Int) = {
- val c = (a - 1 + 'a').toChar.toString * 10
+ val c = (a - 1 + 'a').toChar.toString * 5 + (a - 1 + 'a').toChar.toString.toUpperCase() * 5
!filters.map(translateFilterOnA(_)(a)).contains(false) &&
!filters.map(translateFilterOnC(_)(c)).contains(false)
}
@@ -110,7 +112,8 @@ class FilteredScanSuite extends DataSourceTest {
sqlTest(
"SELECT * FROM oneToTenFiltered",
- (1 to 10).map(i => Row(i, i * 2, (i - 1 + 'a').toChar.toString * 10)).toSeq)
+ (1 to 10).map(i => Row(i, i * 2, (i - 1 + 'a').toChar.toString * 5
+ + (i - 1 + 'a').toChar.toString.toUpperCase() * 5)).toSeq)
sqlTest(
"SELECT a, b FROM oneToTenFiltered",
@@ -182,15 +185,15 @@ class FilteredScanSuite extends DataSourceTest {
sqlTest(
"SELECT a, b, c FROM oneToTenFiltered WHERE c like 'c%'",
- Seq(Row(3, 3 * 2, "c" * 10)))
+ Seq(Row(3, 3 * 2, "c" * 5 + "C" * 5)))
sqlTest(
- "SELECT a, b, c FROM oneToTenFiltered WHERE c like 'd%'",
- Seq(Row(4, 4 * 2, "d" * 10)))
+ "SELECT a, b, c FROM oneToTenFiltered WHERE c like '%D'",
+ Seq(Row(4, 4 * 2, "d" * 5 + "D" * 5)))
sqlTest(
- "SELECT a, b, c FROM oneToTenFiltered WHERE c like '%e%'",
- Seq(Row(5, 5 * 2, "e" * 10)))
+ "SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'",
+ Seq(Row(5, 5 * 2, "e" * 5 + "E" * 5)))
testPushDown("SELECT * FROM oneToTenFiltered WHERE A = 1", 1)
testPushDown("SELECT a FROM oneToTenFiltered WHERE A = 1", 1)
@@ -222,6 +225,15 @@ class FilteredScanSuite extends DataSourceTest {
testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 3 OR a > 8", 4)
testPushDown("SELECT * FROM oneToTenFiltered WHERE NOT (a < 6)", 5)
+ testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'c%'", 1)
+ testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'C%'", 0)
+
+ testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%D'", 1)
+ testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%d'", 0)
+
+ testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", 1)
+ testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%Ee%'", 0)
+
def testPushDown(sqlString: String, expectedCount: Int): Unit = {
test(s"PushDown Returns $expectedCount: $sqlString") {
val queryExecution = sql(sqlString).queryExecution