author     Yijie Shen <henry.yijieshen@gmail.com>    2015-08-13 13:33:39 +0800
committer  Cheng Lian <lian@databricks.com>          2015-08-13 13:33:39 +0800
commit     d0b18919d16e6a2f19159516bd2767b60b595279
tree       1415b92333850a066dd23f385a779d74e6768940
parent     d7eb371eb6369a34e58a09179efe058c4101de9e
[SPARK-9927] [SQL] Revert 8049 since it's pushing wrong filter down
I made a mistake in #8049 by casting the literal value to the attribute's data type, which would simply truncate the literal value and push a wrong filter down.

JIRA: https://issues.apache.org/jira/browse/SPARK-9927

Author: Yijie Shen <henry.yijieshen@gmail.com>

Closes #8157 from yjshen/rever8049.
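To make the failure mode concrete, here is a minimal standalone sketch of the truncation (illustrative only, not code from this patch): when an INT column is compared against a fractional literal, casting the literal down to the attribute's type drops the fraction and changes the predicate's meaning.

    // Hypothetical illustration of the truncation described above.
    object TruncationSketch extends App {
      val a: Int = 0 // an INT column value

      // Spark plans `a < 0.5` by widening the attribute, i.e. Cast(a, DoubleType) < 0.5:
      val original = a.toDouble < 0.5 // true -- the row satisfies the query

      // #8049 instead cast the literal to the attribute's type before pushing
      // the filter down, truncating 0.5 to 0:
      val truncatedLiteral = 0.5.toInt // 0
      val pushed = a < truncatedLiteral // false -- the source wrongly drops the row

      assert(original && !pushed)
    }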
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala  | 30
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala        |  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala                            | 35
3 files changed, 3 insertions(+), 64 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 9eea2b0382..2a4c40db8b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -20,13 +20,13 @@ package org.apache.spark.sql.execution.datasources
import org.apache.spark.{Logging, TaskContext}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD}
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions}
+import org.apache.spark.sql.catalyst.{InternalRow, expressions}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types.{TimestampType, DateType, StringType, StructType}
+import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.{SaveMode, Strategy, execution, sources, _}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -343,17 +343,11 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
* and convert them.
*/
protected[sql] def selectFilters(filters: Seq[Expression]) = {
- import CatalystTypeConverters._
-
def translate(predicate: Expression): Option[Filter] = predicate match {
case expressions.EqualTo(a: Attribute, Literal(v, _)) =>
Some(sources.EqualTo(a.name, v))
case expressions.EqualTo(Literal(v, _), a: Attribute) =>
Some(sources.EqualTo(a.name, v))
- case expressions.EqualTo(Cast(a: Attribute, _), l: Literal) =>
- Some(sources.EqualTo(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
- case expressions.EqualTo(l: Literal, Cast(a: Attribute, _)) =>
- Some(sources.EqualTo(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) =>
Some(sources.EqualNullSafe(a.name, v))
@@ -364,41 +358,21 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
Some(sources.GreaterThan(a.name, v))
case expressions.GreaterThan(Literal(v, _), a: Attribute) =>
Some(sources.LessThan(a.name, v))
- case expressions.GreaterThan(Cast(a: Attribute, _), l: Literal) =>
- Some(sources.GreaterThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
- case expressions.GreaterThan(l: Literal, Cast(a: Attribute, _)) =>
- Some(sources.LessThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
case expressions.LessThan(a: Attribute, Literal(v, _)) =>
Some(sources.LessThan(a.name, v))
case expressions.LessThan(Literal(v, _), a: Attribute) =>
Some(sources.GreaterThan(a.name, v))
- case expressions.LessThan(Cast(a: Attribute, _), l: Literal) =>
- Some(sources.LessThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
- case expressions.LessThan(l: Literal, Cast(a: Attribute, _)) =>
- Some(sources.GreaterThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) =>
Some(sources.GreaterThanOrEqual(a.name, v))
case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) =>
Some(sources.LessThanOrEqual(a.name, v))
- case expressions.GreaterThanOrEqual(Cast(a: Attribute, _), l: Literal) =>
- Some(sources.GreaterThanOrEqual(a.name,
- convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
- case expressions.GreaterThanOrEqual(l: Literal, Cast(a: Attribute, _)) =>
- Some(sources.LessThanOrEqual(a.name,
- convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) =>
Some(sources.LessThanOrEqual(a.name, v))
case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) =>
Some(sources.GreaterThanOrEqual(a.name, v))
- case expressions.LessThanOrEqual(Cast(a: Attribute, _), l: Literal) =>
- Some(sources.LessThanOrEqual(a.name,
- convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
- case expressions.LessThanOrEqual(l: Literal, Cast(a: Attribute, _)) =>
- Some(sources.GreaterThanOrEqual(a.name,
- convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
case expressions.InSet(a: Attribute, set) =>
Some(sources.In(a.name, set.toArray))
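
The net effect of the hunk above: after the revert, translate matches only bare Attribute-vs-Literal comparisons, so a predicate whose attribute sits under a Cast yields None and stays with Spark instead of being pushed down wrongly. A self-contained sketch of that shape, using simplified stand-ins rather than Catalyst's real classes:

    // Simplified stand-ins for Catalyst expressions; all names here are illustrative.
    sealed trait Expr
    case class Attr(name: String) extends Expr
    case class Lit(value: Any) extends Expr
    case class Cast(child: Expr) extends Expr
    case class Gt(left: Expr, right: Expr) extends Expr

    case class SourceGreaterThan(attribute: String, value: Any)
    case class SourceLessThan(attribute: String, value: Any)

    object TranslateSketch extends App {
      // Mirrors the post-revert behavior: Cast-wrapped attributes fall through.
      def translate(p: Expr): Option[Any] = p match {
        case Gt(Attr(n), Lit(v)) => Some(SourceGreaterThan(n, v))
        case Gt(Lit(v), Attr(n)) => Some(SourceLessThan(n, v)) // flipped operand order
        case _                   => None // no pushdown rather than a wrong one
      }

      assert(translate(Gt(Attr("A"), Lit(15))).contains(SourceGreaterThan("A", 15)))
      assert(translate(Gt(Cast(Attr("A")), Lit(0.5))).isEmpty)
    }
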
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 281943e23f..8eab6a0adc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -284,7 +284,7 @@ private[sql] class JDBCRDD(
/**
* `filters`, but as a WHERE clause suitable for injection into a SQL query.
*/
- val filterWhereClause: String = {
+ private val filterWhereClause: String = {
val filterStrings = filters map compileFilter filter (_ != null)
if (filterStrings.size > 0) {
val sb = new StringBuilder("WHERE ")
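
For context on the hunk above: the diff cuts off below the StringBuilder line, so the clause assembly is not shown. A hedged standalone sketch of the likely shape (the AND-joining step is an assumption, not confirmed by the patch):

    // Assumed shape of the WHERE-clause assembly; only the lines visible in
    // the hunk above are confirmed by the diff.
    def buildWhereClause(compiled: Seq[String]): String = {
      // compileFilter returns null for filters it cannot translate to SQL.
      val fragments = compiled.filter(_ != null)
      if (fragments.nonEmpty) fragments.mkString("WHERE ", " AND ", "") else ""
    }

    // e.g. buildWhereClause(Seq("A > 15", null, "C <= 20")) == "WHERE A > 15 AND C <= 20"
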
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index b9cfae51e8..e4dcf4c75d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -25,8 +25,6 @@ import org.h2.jdbc.JdbcSQLException
import org.scalatest.BeforeAndAfter
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD
-import org.apache.spark.sql.execution.PhysicalRDD
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -150,18 +148,6 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
|OPTIONS (url '$url', dbtable 'TEST.FLTTYPES', user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))
- conn.prepareStatement("create table test.decimals (a DECIMAL(7, 2), b DECIMAL(4, 0))").
- executeUpdate()
- conn.prepareStatement("insert into test.decimals values (12345.67, 1234)").executeUpdate()
- conn.prepareStatement("insert into test.decimals values (34567.89, 1428)").executeUpdate()
- conn.commit()
- sql(
- s"""
- |CREATE TEMPORARY TABLE decimals
- |USING org.apache.spark.sql.jdbc
- |OPTIONS (url '$url', dbtable 'TEST.DECIMALS', user 'testUser', password 'testPass')
- """.stripMargin.replaceAll("\n", " "))
-
conn.prepareStatement(
s"""
|create table test.nulltypes (a INT, b BOOLEAN, c TINYINT, d BINARY(20), e VARCHAR(20),
@@ -458,25 +444,4 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
assert(agg.getCatalystType(0, "", 1, null) === Some(LongType))
assert(agg.getCatalystType(1, "", 1, null) === Some(StringType))
}
-
- test("SPARK-9182: filters are not passed through to jdbc source") {
- def checkPushedFilter(query: String, filterStr: String): Unit = {
- val rddOpt = sql(query).queryExecution.executedPlan.collectFirst {
- case PhysicalRDD(_, rdd: JDBCRDD, _) => rdd
- }
- assert(rddOpt.isDefined)
- val pushedFilterStr = rddOpt.get.filterWhereClause
- assert(pushedFilterStr.contains(filterStr),
- s"Expected to push [$filterStr], actually we pushed [$pushedFilterStr]")
- }
-
- checkPushedFilter("select * from foobar where NAME = 'fred'", "NAME = 'fred'")
- checkPushedFilter("select * from inttypes where A > '15'", "A > 15")
- checkPushedFilter("select * from inttypes where C <= 20", "C <= 20")
- checkPushedFilter("select * from decimals where A > 1000", "A > 1000.00")
- checkPushedFilter("select * from decimals where A > 1000 AND A < 2000",
- "A > 1000.00 AND A < 2000.00")
- checkPushedFilter("select * from decimals where A = 2000 AND B > 20", "A = 2000.00 AND B > 20")
- checkPushedFilter("select * from timetypes where B > '1998-09-10'", "B > 1998-09-10")
- }
}