diff options
Diffstat (limited to 'sql/core')
5 files changed, 21 insertions, 3 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index 9b8c6a56b9..954e86822d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -108,5 +108,7 @@ private[sql] object DataSourceStrategy extends Strategy { case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => LessThanOrEqual(a.name, v) case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => GreaterThanOrEqual(a.name, v) + + case expressions.InSet(a: Attribute, set) => In(a.name, set.toArray) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala index 82a2cf8402..4d87f6817d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala @@ -41,8 +41,7 @@ private[sql] case class LogicalRelation(relation: BaseRelation) } @transient override lazy val statistics = Statistics( - // TODO: Allow datasources to provide statistics as well. 
- sizeInBytes = BigInt(relation.sqlContext.defaultSizeInBytes) + sizeInBytes = BigInt(relation.sizeInBytes) ) /** Used to lookup original attribute capitalization */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index e72a2aeb8f..4a9fefc12b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -24,3 +24,4 @@ case class GreaterThan(attribute: String, value: Any) extends Filter case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter case class LessThan(attribute: String, value: Any) extends Filter case class LessThanOrEqual(attribute: String, value: Any) extends Filter +case class In(attribute: String, values: Array[Any]) extends Filter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index ac3bf9d8e1..861638b1e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.sources import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext, StructType} +import org.apache.spark.sql.{SQLConf, Row, SQLContext, StructType} import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute} /** @@ -53,6 +53,15 @@ trait RelationProvider { abstract class BaseRelation { def sqlContext: SQLContext def schema: StructType + + /** + * Returns an estimated size of this relation in bytes. This information is used by the planner + * to decide when it is safe to broadcast a relation and can be overridden by sources that + * know the size ahead of time. By default, the system will assume that tables are too + * large to broadcast. 
This method will be called multiple times during query planning + * and thus should not perform expensive operations for each invocation. + */ + def sizeInBytes = sqlContext.defaultSizeInBytes } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 8b2f1591d5..939b3c0c66 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -51,6 +51,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v case GreaterThan("a", v: Int) => (a: Int) => a > v case GreaterThanOrEqual("a", v: Int) => (a: Int) => a >= v + case In("a", values) => (a: Int) => values.map(_.asInstanceOf[Int]).toSet.contains(a) } def eval(a: Int) = !filterFunctions.map(_(a)).contains(false) @@ -122,6 +123,10 @@ class FilteredScanSuite extends DataSourceTest { Seq(1).map(i => Row(i, i * 2)).toSeq) sqlTest( + "SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)", + Seq(1,3,5).map(i => Row(i, i * 2)).toSeq) + + sqlTest( "SELECT * FROM oneToTenFiltered WHERE A = 1", Seq(1).map(i => Row(i, i * 2)).toSeq) @@ -150,6 +155,8 @@ class FilteredScanSuite extends DataSourceTest { testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1 AND a < 10", 8) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)", 3) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0) testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10) |