Diffstat (limited to 'sql')
 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala     | 67
 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala           |  2
 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala | 19
 sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala                    |  7
 4 files changed, 54 insertions(+), 41 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 2a4c40db8b..6c1ef6a6df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql.execution.datasources
import org.apache.spark.{Logging, TaskContext}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD}
-import org.apache.spark.sql.catalyst.{InternalRow, expressions}
+import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
@@ -344,45 +345,47 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
*/
protected[sql] def selectFilters(filters: Seq[Expression]) = {
def translate(predicate: Expression): Option[Filter] = predicate match {
- case expressions.EqualTo(a: Attribute, Literal(v, _)) =>
- Some(sources.EqualTo(a.name, v))
- case expressions.EqualTo(Literal(v, _), a: Attribute) =>
- Some(sources.EqualTo(a.name, v))
-
- case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) =>
- Some(sources.EqualNullSafe(a.name, v))
- case expressions.EqualNullSafe(Literal(v, _), a: Attribute) =>
- Some(sources.EqualNullSafe(a.name, v))
-
- case expressions.GreaterThan(a: Attribute, Literal(v, _)) =>
- Some(sources.GreaterThan(a.name, v))
- case expressions.GreaterThan(Literal(v, _), a: Attribute) =>
- Some(sources.LessThan(a.name, v))
-
- case expressions.LessThan(a: Attribute, Literal(v, _)) =>
- Some(sources.LessThan(a.name, v))
- case expressions.LessThan(Literal(v, _), a: Attribute) =>
- Some(sources.GreaterThan(a.name, v))
-
- case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) =>
- Some(sources.GreaterThanOrEqual(a.name, v))
- case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) =>
- Some(sources.LessThanOrEqual(a.name, v))
-
- case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) =>
- Some(sources.LessThanOrEqual(a.name, v))
- case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) =>
- Some(sources.GreaterThanOrEqual(a.name, v))
+ case expressions.EqualTo(a: Attribute, Literal(v, t)) =>
+ Some(sources.EqualTo(a.name, convertToScala(v, t)))
+ case expressions.EqualTo(Literal(v, t), a: Attribute) =>
+ Some(sources.EqualTo(a.name, convertToScala(v, t)))
+
+ case expressions.EqualNullSafe(a: Attribute, Literal(v, t)) =>
+ Some(sources.EqualNullSafe(a.name, convertToScala(v, t)))
+ case expressions.EqualNullSafe(Literal(v, t), a: Attribute) =>
+ Some(sources.EqualNullSafe(a.name, convertToScala(v, t)))
+
+ case expressions.GreaterThan(a: Attribute, Literal(v, t)) =>
+ Some(sources.GreaterThan(a.name, convertToScala(v, t)))
+ case expressions.GreaterThan(Literal(v, t), a: Attribute) =>
+ Some(sources.LessThan(a.name, convertToScala(v, t)))
+
+ case expressions.LessThan(a: Attribute, Literal(v, t)) =>
+ Some(sources.LessThan(a.name, convertToScala(v, t)))
+ case expressions.LessThan(Literal(v, t), a: Attribute) =>
+ Some(sources.GreaterThan(a.name, convertToScala(v, t)))
+
+ case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, t)) =>
+ Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t)))
+ case expressions.GreaterThanOrEqual(Literal(v, t), a: Attribute) =>
+ Some(sources.LessThanOrEqual(a.name, convertToScala(v, t)))
+
+ case expressions.LessThanOrEqual(a: Attribute, Literal(v, t)) =>
+ Some(sources.LessThanOrEqual(a.name, convertToScala(v, t)))
+ case expressions.LessThanOrEqual(Literal(v, t), a: Attribute) =>
+ Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t)))
case expressions.InSet(a: Attribute, set) =>
- Some(sources.In(a.name, set.toArray))
+ val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType)
+ Some(sources.In(a.name, set.toArray.map(toScala)))
      // Because we only convert In to InSet in the Optimizer when there are more than a
      // certain number of items, it is possible we still get an In expression here that
      // needs to be pushed down.
case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) =>
val hSet = list.map(e => e.eval(EmptyRow))
- Some(sources.In(a.name, hSet.toArray))
+ val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType)
+ Some(sources.In(a.name, hSet.toArray.map(toScala)))
case expressions.IsNull(a: Attribute) =>
Some(sources.IsNull(a.name))
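
The hunk above relies on CatalystTypeConverters to turn a Literal's internal value into a plain Scala value before it is wrapped in a public sources.Filter. A minimal sketch of that conversion, assuming Spark 1.5-era APIs (UTF8String, convertToScala); the value used here is illustrative only:

import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

// What a Catalyst Literal of StringType holds internally.
val internal: Any = UTF8String.fromString("aaaaaAAAAA")
// convertToScala maps the internal value to a plain java.lang.String, so
// sources.EqualTo("c", ...) is built with a Scala value rather than a UTF8String.
val external: Any = convertToScala(internal, StringType)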
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index e537d631f4..730d88b024 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -262,7 +262,7 @@ private[sql] class JDBCRDD(
* Converts value to SQL expression.
*/
private def compileValue(value: Any): Any = value match {
- case stringValue: UTF8String => s"'${escapeSql(stringValue.toString)}'"
+ case stringValue: String => s"'${escapeSql(stringValue)}'"
case _ => value
}
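
With Scala Strings now flowing into the pushed-down filters, compileValue can match on String directly. A minimal sketch of the quoting behaviour, with a simplified escapeSql standing in for the helper JDBCRDD actually uses:

// Simplified stand-in for JDBCRDD's escapeSql helper (assumption: doubling single quotes).
def escapeSql(value: String): String = value.replace("'", "''")

def compileValue(value: Any): Any = value match {
  case stringValue: String => s"'${escapeSql(stringValue)}'" // quote and escape string literals
  case _ => value                                            // numbers etc. pass through unchanged
}

// compileValue("O'Brien") == "'O''Brien'", ready to splice into a WHERE clause.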
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
index c74c838863..c6b3fe7900 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
@@ -32,7 +32,6 @@ import org.apache.spark.SparkEnv
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.sources
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
private[sql] object ParquetFilters {
val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter"
@@ -65,7 +64,7 @@ private[sql] object ParquetFilters {
case StringType =>
(n: String, v: Any) => FilterApi.eq(
binaryColumn(n),
- Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull)
+ Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8"))).orNull)
case BinaryType =>
(n: String, v: Any) => FilterApi.eq(
binaryColumn(n),
@@ -86,7 +85,7 @@ private[sql] object ParquetFilters {
case StringType =>
(n: String, v: Any) => FilterApi.notEq(
binaryColumn(n),
- Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull)
+ Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8"))).orNull)
case BinaryType =>
(n: String, v: Any) => FilterApi.notEq(
binaryColumn(n),
@@ -104,7 +103,8 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
- FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
+ FilterApi.lt(binaryColumn(n),
+ Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8")))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -121,7 +121,8 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
- FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
+ FilterApi.ltEq(binaryColumn(n),
+ Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8")))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -138,7 +139,8 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
- FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
+ FilterApi.gt(binaryColumn(n),
+ Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8")))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -155,7 +157,8 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
- FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
+ FilterApi.gtEq(binaryColumn(n),
+ Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8")))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -177,7 +180,7 @@ private[sql] object ParquetFilters {
case StringType =>
(n: String, v: Set[Any]) =>
FilterApi.userDefined(binaryColumn(n),
- SetInFilter(v.map(e => Binary.fromByteArray(e.asInstanceOf[UTF8String].getBytes))))
+ SetInFilter(v.map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8")))))
case BinaryType =>
(n: String, v: Set[Any]) =>
FilterApi.userDefined(binaryColumn(n),
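
After this change, every StringType branch builds the Parquet Binary from the String's UTF-8 bytes rather than from UTF8String's internal buffer. A minimal sketch, assuming parquet-mr's org.apache.parquet.io.api.Binary is on the classpath:

import org.apache.parquet.io.api.Binary

// The value handed to the filter factory is now a plain Scala String.
val v: Any = "aaaaaAAAAA"
// Encode it to UTF-8 bytes and wrap it as a Parquet Binary, as the predicates above do.
val binary: Binary = Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8"))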
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
index c81c3d3982..68ce37c000 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.sources
import scala.language.existentials
import org.apache.spark.rdd.RDD
+import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.sql._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@@ -78,6 +79,9 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL
case StringStartsWith("c", v) => _.startsWith(v)
case StringEndsWith("c", v) => _.endsWith(v)
case StringContains("c", v) => _.contains(v)
+ case EqualTo("c", v: String) => _.equals(v)
+ case EqualTo("c", v: UTF8String) => sys.error("UTF8String should not appear in filters")
+ case In("c", values) => (s: String) => values.map(_.asInstanceOf[String]).toSet.contains(s)
case _ => (c: String) => true
}
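
The new cases above assert that a data source's filter callback sees Scala values: matching on String succeeds, and a UTF8String reaching the source is treated as an error. A minimal, hypothetical example (names are illustrative, not part of the test) of handling such filters in a source implementation:

import org.apache.spark.sql.sources.{EqualTo, Filter, In}

// Hypothetical helper: turn a pushed-down filter on column "c" into a predicate
// over that column's String values.
def toPredicate(filter: Filter): String => Boolean = filter match {
  case EqualTo("c", v: String) => _ == v
  case In("c", values) => (s: String) => values.map(_.asInstanceOf[String]).toSet.contains(s)
  case _ => _ => true
}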
@@ -237,6 +241,9 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext {
testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", 1)
testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%Ee%'", 0)
+ testPushDown("SELECT c FROM oneToTenFiltered WHERE c = 'aaaaaAAAAA'", 1)
+ testPushDown("SELECT c FROM oneToTenFiltered WHERE c IN ('aaaaaAAAAA', 'foo')", 1)
+
def testPushDown(sqlString: String, expectedCount: Int): Unit = {
test(s"PushDown Returns $expectedCount: $sqlString") {
val queryExecution = sql(sqlString).queryExecution