author     Yijie Shen <henry.yijieshen@gmail.com>    2015-08-12 19:54:00 +0800
committer  Cheng Lian <lian@databricks.com>          2015-08-12 19:54:00 +0800
commit     9d0822455ddc8d765440d58c463367a4d67ef456 (patch)
tree       47b01cd45d890cf948d915e7115833dd68d73c10 /sql
parent     741a29f98945538a475579ccc974cd42c1613be4 (diff)
[SPARK-9182] [SQL] Filters are not passed through to jdbc source
This PR fixes filters not being pushed down to the JDBC source when a `Cast` is introduced during pattern matching. When a column is compared against a literal of a different type, the analyzer usually wraps the column in a cast, so the predicate no longer matches the patterns written against a bare `Attribute` and the filter fails to push down.

Author: Yijie Shen <henry.yijieshen@gmail.com>

Closes #8049 from yjshen/jdbc_pushdown.
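To illustrate the idea (not part of the commit): a minimal, self-contained Scala sketch using simplified stand-in types rather than the real Catalyst classes. It shows why a cast around the column defeats the plain Attribute/Literal patterns, and how an extra case that matches Cast(attr, _) and evaluates the cast on the literal instead keeps the filter pushable.

    object CastPushdownSketch {
      // Simplified, hypothetical stand-ins for Catalyst expressions.
      sealed trait Expr
      case class Attr(name: String, dataType: String) extends Expr
      case class Lit(value: Any) extends Expr
      case class Cast(child: Expr, to: String) extends Expr
      case class GreaterThan(left: Expr, right: Expr) extends Expr

      // Stand-in for the data source filter (analogous to sources.GreaterThan).
      case class SourceGreaterThan(column: String, value: Any)

      // Convert the literal to the column's type, mirroring
      // Cast(l, a.dataType).eval() in the real patch.
      private def coerce(value: Any, dataType: String): Any = dataType match {
        case "int" => value.toString.toInt
        case _     => value
      }

      def translate(predicate: Expr): Option[SourceGreaterThan] = predicate match {
        // Plain column-vs-literal comparison: translates directly.
        case GreaterThan(a: Attr, Lit(v)) =>
          Some(SourceGreaterThan(a.name, v))
        // Column wrapped in a cast, e.g. an INT column compared with a string
        // literal. Without this case the predicate matches nothing and no
        // filter is pushed to the source.
        case GreaterThan(Cast(a: Attr, _), Lit(v)) =>
          Some(SourceGreaterThan(a.name, coerce(v, a.dataType)))
        case _ => None
      }

      def main(args: Array[String]): Unit = {
        // `A > '15'` on an INT column is analyzed with a cast on the column:
        val p = GreaterThan(Cast(Attr("A", "int"), "string"), Lit("15"))
        println(translate(p)) // Some(SourceGreaterThan(A,15))
      }
    }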
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala | 30
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala       |  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala                           | 34
3 files changed, 63 insertions(+), 3 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 2a4c40db8b..9eea2b0382 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -20,13 +20,13 @@ package org.apache.spark.sql.execution.datasources
import org.apache.spark.{Logging, TaskContext}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD}
-import org.apache.spark.sql.catalyst.{InternalRow, expressions}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.sql.types.{TimestampType, DateType, StringType, StructType}
import org.apache.spark.sql.{SaveMode, Strategy, execution, sources, _}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -343,11 +343,17 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
* and convert them.
*/
protected[sql] def selectFilters(filters: Seq[Expression]) = {
+ import CatalystTypeConverters._
+
def translate(predicate: Expression): Option[Filter] = predicate match {
case expressions.EqualTo(a: Attribute, Literal(v, _)) =>
Some(sources.EqualTo(a.name, v))
case expressions.EqualTo(Literal(v, _), a: Attribute) =>
Some(sources.EqualTo(a.name, v))
+ case expressions.EqualTo(Cast(a: Attribute, _), l: Literal) =>
+ Some(sources.EqualTo(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
+ case expressions.EqualTo(l: Literal, Cast(a: Attribute, _)) =>
+ Some(sources.EqualTo(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) =>
Some(sources.EqualNullSafe(a.name, v))
@@ -358,21 +364,41 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
Some(sources.GreaterThan(a.name, v))
case expressions.GreaterThan(Literal(v, _), a: Attribute) =>
Some(sources.LessThan(a.name, v))
+ case expressions.GreaterThan(Cast(a: Attribute, _), l: Literal) =>
+ Some(sources.GreaterThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
+ case expressions.GreaterThan(l: Literal, Cast(a: Attribute, _)) =>
+ Some(sources.LessThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
case expressions.LessThan(a: Attribute, Literal(v, _)) =>
Some(sources.LessThan(a.name, v))
case expressions.LessThan(Literal(v, _), a: Attribute) =>
Some(sources.GreaterThan(a.name, v))
+ case expressions.LessThan(Cast(a: Attribute, _), l: Literal) =>
+ Some(sources.LessThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
+ case expressions.LessThan(l: Literal, Cast(a: Attribute, _)) =>
+ Some(sources.GreaterThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) =>
Some(sources.GreaterThanOrEqual(a.name, v))
case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) =>
Some(sources.LessThanOrEqual(a.name, v))
+ case expressions.GreaterThanOrEqual(Cast(a: Attribute, _), l: Literal) =>
+ Some(sources.GreaterThanOrEqual(a.name,
+ convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
+ case expressions.GreaterThanOrEqual(l: Literal, Cast(a: Attribute, _)) =>
+ Some(sources.LessThanOrEqual(a.name,
+ convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) =>
Some(sources.LessThanOrEqual(a.name, v))
case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) =>
Some(sources.GreaterThanOrEqual(a.name, v))
+ case expressions.LessThanOrEqual(Cast(a: Attribute, _), l: Literal) =>
+ Some(sources.LessThanOrEqual(a.name,
+ convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
+ case expressions.LessThanOrEqual(l: Literal, Cast(a: Attribute, _)) =>
+ Some(sources.GreaterThanOrEqual(a.name,
+ convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
case expressions.InSet(a: Attribute, set) =>
Some(sources.In(a.name, set.toArray))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 8eab6a0adc..281943e23f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -284,7 +284,7 @@ private[sql] class JDBCRDD(
/**
* `filters`, but as a WHERE clause suitable for injection into a SQL query.
*/
- private val filterWhereClause: String = {
+ val filterWhereClause: String = {
val filterStrings = filters map compileFilter filter (_ != null)
if (filterStrings.size > 0) {
val sb = new StringBuilder("WHERE ")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 42f2449afb..b9cfae51e8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -25,6 +25,8 @@ import org.h2.jdbc.JdbcSQLException
import org.scalatest.BeforeAndAfter
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD
+import org.apache.spark.sql.execution.PhysicalRDD
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -148,6 +150,18 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
|OPTIONS (url '$url', dbtable 'TEST.FLTTYPES', user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))
+ conn.prepareStatement("create table test.decimals (a DECIMAL(7, 2), b DECIMAL(4, 0))").
+ executeUpdate()
+ conn.prepareStatement("insert into test.decimals values (12345.67, 1234)").executeUpdate()
+ conn.prepareStatement("insert into test.decimals values (34567.89, 1428)").executeUpdate()
+ conn.commit()
+ sql(
+ s"""
+ |CREATE TEMPORARY TABLE decimals
+ |USING org.apache.spark.sql.jdbc
+ |OPTIONS (url '$url', dbtable 'TEST.DECIMALS', user 'testUser', password 'testPass')
+ """.stripMargin.replaceAll("\n", " "))
+
conn.prepareStatement(
s"""
|create table test.nulltypes (a INT, b BOOLEAN, c TINYINT, d BINARY(20), e VARCHAR(20),
@@ -445,4 +459,24 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
assert(agg.getCatalystType(1, "", 1, null) === Some(StringType))
}
+ test("SPARK-9182: filters are not passed through to jdbc source") {
+ def checkPushedFilter(query: String, filterStr: String): Unit = {
+ val rddOpt = sql(query).queryExecution.executedPlan.collectFirst {
+ case PhysicalRDD(_, rdd: JDBCRDD, _) => rdd
+ }
+ assert(rddOpt.isDefined)
+ val pushedFilterStr = rddOpt.get.filterWhereClause
+ assert(pushedFilterStr.contains(filterStr),
+ s"Expected to push [$filterStr], actually we pushed [$pushedFilterStr]")
+ }
+
+ checkPushedFilter("select * from foobar where NAME = 'fred'", "NAME = 'fred'")
+ checkPushedFilter("select * from inttypes where A > '15'", "A > 15")
+ checkPushedFilter("select * from inttypes where C <= 20", "C <= 20")
+ checkPushedFilter("select * from decimals where A > 1000", "A > 1000.00")
+ checkPushedFilter("select * from decimals where A > 1000 AND A < 2000",
+ "A > 1000.00 AND A < 2000.00")
+ checkPushedFilter("select * from decimals where A = 2000 AND B > 20", "A = 2000.00 AND B > 20")
+ checkPushedFilter("select * from timetypes where B > '1998-09-10'", "B > 1998-09-10")
+ }
}