author    Josh Rosen <joshrosen@databricks.com>    2015-08-25 01:06:36 -0700
committer Reynold Xin <rxin@databricks.com>        2015-08-25 01:06:36 -0700
commit    7bc9a8c6249300ded31ea931c463d0a8f798e193 (patch)
tree      f443f097acdfed3dc0d31b5aa7dd7d620a7d6ddb /sql
parent    2f493f7e3924b769160a16f73cccbebf21973b91 (diff)
[SPARK-10195] [SQL] Data sources Filter should not expose internal types
Spark SQL's data sources API exposes Catalyst's internal types through its Filter interfaces. This is a problem because types like UTF8String are not stable developer APIs and should not be exposed to third parties. This issue caused incompatibilities when upgrading our `spark-redshift` library to work against Spark 1.5.0. To avoid such issues in the future, we should only expose public types through these Filter objects. This patch accomplishes that by using CatalystTypeConverters to add the appropriate conversions.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #8403 from JoshRosen/datasources-internal-vs-external-types.
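As a rough sketch of the conversion this patch introduces (not the patched method itself; the helper name translateEqualTo is made up for illustration), the Catalyst-internal value carried by a Literal is passed through CatalystTypeConverters before it is placed into a public sources.Filter:

import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
import org.apache.spark.sql.sources

// Hypothetical helper mirroring one case of DataSourceStrategy.selectFilters below:
// the internal literal value (e.g. UTF8String for StringType) is converted to the
// corresponding public Scala type (e.g. String) before it reaches the Filter API.
def translateEqualTo(predicate: expressions.Expression): Option[sources.Filter] =
  predicate match {
    case expressions.EqualTo(a: Attribute, Literal(v, t)) =>
      Some(sources.EqualTo(a.name, convertToScala(v, t)))
    case _ => None
  }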
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala     | 67
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala            |  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala  | 19
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala                     |  7
4 files changed, 54 insertions, 41 deletions
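On the consumer side, a minimal sketch (with a hypothetical relation name; it mirrors the assertions added to FilteredScanSuite below) of a PrunedFilteredScan implementation that can now rely on pushed-down filter values being public Scala types such as String rather than UTF8String:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, EqualTo, Filter, In, PrunedFilteredScan}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

// Hypothetical third-party relation: after this patch, string-typed filter values
// arrive as plain java.lang.String, never as Catalyst's internal UTF8String.
class ExampleRelation(override val sqlContext: SQLContext)
  extends BaseRelation with PrunedFilteredScan {

  override def schema: StructType = StructType(StructField("c", StringType) :: Nil)

  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    filters.foreach {
      case EqualTo("c", _: String) => ()  // public type, safe to use directly
      case In("c", values) => values.foreach(v => assert(v.isInstanceOf[String]))
      case _ => ()                        // other filters ignored in this sketch
    }
    sqlContext.sparkContext.emptyRDD[Row]
  }
}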
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 2a4c40db8b..6c1ef6a6df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql.execution.datasources
import org.apache.spark.{Logging, TaskContext}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD}
-import org.apache.spark.sql.catalyst.{InternalRow, expressions}
+import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
@@ -344,45 +345,47 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
*/
protected[sql] def selectFilters(filters: Seq[Expression]) = {
def translate(predicate: Expression): Option[Filter] = predicate match {
- case expressions.EqualTo(a: Attribute, Literal(v, _)) =>
- Some(sources.EqualTo(a.name, v))
- case expressions.EqualTo(Literal(v, _), a: Attribute) =>
- Some(sources.EqualTo(a.name, v))
-
- case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) =>
- Some(sources.EqualNullSafe(a.name, v))
- case expressions.EqualNullSafe(Literal(v, _), a: Attribute) =>
- Some(sources.EqualNullSafe(a.name, v))
-
- case expressions.GreaterThan(a: Attribute, Literal(v, _)) =>
- Some(sources.GreaterThan(a.name, v))
- case expressions.GreaterThan(Literal(v, _), a: Attribute) =>
- Some(sources.LessThan(a.name, v))
-
- case expressions.LessThan(a: Attribute, Literal(v, _)) =>
- Some(sources.LessThan(a.name, v))
- case expressions.LessThan(Literal(v, _), a: Attribute) =>
- Some(sources.GreaterThan(a.name, v))
-
- case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) =>
- Some(sources.GreaterThanOrEqual(a.name, v))
- case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) =>
- Some(sources.LessThanOrEqual(a.name, v))
-
- case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) =>
- Some(sources.LessThanOrEqual(a.name, v))
- case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) =>
- Some(sources.GreaterThanOrEqual(a.name, v))
+ case expressions.EqualTo(a: Attribute, Literal(v, t)) =>
+ Some(sources.EqualTo(a.name, convertToScala(v, t)))
+ case expressions.EqualTo(Literal(v, t), a: Attribute) =>
+ Some(sources.EqualTo(a.name, convertToScala(v, t)))
+
+ case expressions.EqualNullSafe(a: Attribute, Literal(v, t)) =>
+ Some(sources.EqualNullSafe(a.name, convertToScala(v, t)))
+ case expressions.EqualNullSafe(Literal(v, t), a: Attribute) =>
+ Some(sources.EqualNullSafe(a.name, convertToScala(v, t)))
+
+ case expressions.GreaterThan(a: Attribute, Literal(v, t)) =>
+ Some(sources.GreaterThan(a.name, convertToScala(v, t)))
+ case expressions.GreaterThan(Literal(v, t), a: Attribute) =>
+ Some(sources.LessThan(a.name, convertToScala(v, t)))
+
+ case expressions.LessThan(a: Attribute, Literal(v, t)) =>
+ Some(sources.LessThan(a.name, convertToScala(v, t)))
+ case expressions.LessThan(Literal(v, t), a: Attribute) =>
+ Some(sources.GreaterThan(a.name, convertToScala(v, t)))
+
+ case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, t)) =>
+ Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t)))
+ case expressions.GreaterThanOrEqual(Literal(v, t), a: Attribute) =>
+ Some(sources.LessThanOrEqual(a.name, convertToScala(v, t)))
+
+ case expressions.LessThanOrEqual(a: Attribute, Literal(v, t)) =>
+ Some(sources.LessThanOrEqual(a.name, convertToScala(v, t)))
+ case expressions.LessThanOrEqual(Literal(v, t), a: Attribute) =>
+ Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t)))
case expressions.InSet(a: Attribute, set) =>
- Some(sources.In(a.name, set.toArray))
+ val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType)
+ Some(sources.In(a.name, set.toArray.map(toScala)))
// Because we only convert In to InSet in Optimizer when there are more than certain
// items. So it is possible we still get an In expression here that needs to be pushed
// down.
case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) =>
val hSet = list.map(e => e.eval(EmptyRow))
- Some(sources.In(a.name, hSet.toArray))
+ val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType)
+ Some(sources.In(a.name, hSet.toArray.map(toScala)))
case expressions.IsNull(a: Attribute) =>
Some(sources.IsNull(a.name))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index e537d631f4..730d88b024 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -262,7 +262,7 @@ private[sql] class JDBCRDD(
* Converts value to SQL expression.
*/
private def compileValue(value: Any): Any = value match {
- case stringValue: UTF8String => s"'${escapeSql(stringValue.toString)}'"
+ case stringValue: String => s"'${escapeSql(stringValue)}'"
case _ => value
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
index c74c838863..c6b3fe7900 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
@@ -32,7 +32,6 @@ import org.apache.spark.SparkEnv
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.sources
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
private[sql] object ParquetFilters {
val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter"
@@ -65,7 +64,7 @@ private[sql] object ParquetFilters {
case StringType =>
(n: String, v: Any) => FilterApi.eq(
binaryColumn(n),
- Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull)
+ Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8"))).orNull)
case BinaryType =>
(n: String, v: Any) => FilterApi.eq(
binaryColumn(n),
@@ -86,7 +85,7 @@ private[sql] object ParquetFilters {
case StringType =>
(n: String, v: Any) => FilterApi.notEq(
binaryColumn(n),
- Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull)
+ Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8"))).orNull)
case BinaryType =>
(n: String, v: Any) => FilterApi.notEq(
binaryColumn(n),
@@ -104,7 +103,8 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
- FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
+ FilterApi.lt(binaryColumn(n),
+ Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8")))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -121,7 +121,8 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
- FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
+ FilterApi.ltEq(binaryColumn(n),
+ Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8")))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -138,7 +139,8 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
- FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
+ FilterApi.gt(binaryColumn(n),
+ Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8")))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -155,7 +157,8 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
- FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
+ FilterApi.gtEq(binaryColumn(n),
+ Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8")))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -177,7 +180,7 @@ private[sql] object ParquetFilters {
case StringType =>
(n: String, v: Set[Any]) =>
FilterApi.userDefined(binaryColumn(n),
- SetInFilter(v.map(e => Binary.fromByteArray(e.asInstanceOf[UTF8String].getBytes))))
+ SetInFilter(v.map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8")))))
case BinaryType =>
(n: String, v: Set[Any]) =>
FilterApi.userDefined(binaryColumn(n),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
index c81c3d3982..68ce37c000 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.sources
import scala.language.existentials
import org.apache.spark.rdd.RDD
+import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.sql._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@@ -78,6 +79,9 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL
case StringStartsWith("c", v) => _.startsWith(v)
case StringEndsWith("c", v) => _.endsWith(v)
case StringContains("c", v) => _.contains(v)
+ case EqualTo("c", v: String) => _.equals(v)
+ case EqualTo("c", v: UTF8String) => sys.error("UTF8String should not appear in filters")
+ case In("c", values) => (s: String) => values.map(_.asInstanceOf[String]).toSet.contains(s)
case _ => (c: String) => true
}
@@ -237,6 +241,9 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext {
testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", 1)
testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%Ee%'", 0)
+ testPushDown("SELECT c FROM oneToTenFiltered WHERE c = 'aaaaaAAAAA'", 1)
+ testPushDown("SELECT c FROM oneToTenFiltered WHERE c IN ('aaaaaAAAAA', 'foo')", 1)
+
def testPushDown(sqlString: String, expectedCount: Int): Unit = {
test(s"PushDown Returns $expectedCount: $sqlString") {
val queryExecution = sql(sqlString).queryExecution