author    Herman van Hovell <hvanhovell@databricks.com>  2017-04-21 10:06:12 +0800
committer Wenchen Fan <wenchen@databricks.com>  2017-04-21 10:06:12 +0800
commit    760c8d088df1d35d7b8942177d47bc1677daf143 (patch)
tree      86e975adf8963dd5455a2f9c11607a9974c21373 /sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
parent    0368eb9d86634c83b3140ce3190cb9e0d0b7fd86 (diff)
[SPARK-20329][SQL] Make timezone aware expression without timezone unresolved
## What changes were proposed in this pull request?

A cast expression with a resolved time zone is not equal to a cast expression without a resolved time zone. The `ResolveAggregateFunctions` rule assumed that these expressions were the same, and would therefore fail to resolve `HAVING` clauses that contain a `Cast` expression. The root cause is that a `TimeZoneAwareExpression` could be considered resolved without a time zone being set. This PR fixes that by keeping a `TimeZoneAwareExpression` unresolved as long as it has no time zone set.

## How was this patch tested?

Added a regression test to the `SQLQueryTestSuite.having` file.

Author: Herman van Hovell <hvanhovell@databricks.com>

Closes #17641 from hvanhovell/SPARK-20329.
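The core of the fix is to tie resolution of a `TimeZoneAwareExpression` to its time zone. Below is a minimal, self-contained sketch of that idea; the trait and the `SimpleCast` stand-in are illustrative simplifications, not Spark's actual `Expression` hierarchy:

```scala
// Sketch: an expression that needs a time zone only counts as resolved
// once one has been supplied.
trait TimeZoneAwareExpression {
  // None until the analyzer injects a time zone.
  def timeZoneId: Option[String]

  // The key change in this PR: no time zone means unresolved, so a Cast
  // with a time zone and one without are never conflated by rules that
  // compare resolved expressions.
  def resolved: Boolean = timeZoneId.isDefined

  // Returns a copy with the time zone set.
  def withTimeZone(tz: String): TimeZoneAwareExpression
}

// Illustrative stand-in for a Cast over some child expression.
case class SimpleCast(child: String, timeZoneId: Option[String] = None)
    extends TimeZoneAwareExpression {
  override def withTimeZone(tz: String): SimpleCast = copy(timeZoneId = Some(tz))
}
```

With this in place, `SimpleCast("ts").resolved` is `false` until the analyzer calls `withTimeZone`, after which resolution (and `HAVING` analysis in `ResolveAggregateFunctions`) can proceed.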
Diffstat (limited to 'sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala')
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala | 20
1 file changed, 12 insertions(+), 8 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 2d83d512e7..d307122b5c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QualifiedTableName, TableIdentifier}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QualifiedTableName}
import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogUtils}
@@ -48,7 +48,7 @@ import org.apache.spark.unsafe.types.UTF8String
* Note that, this rule must be run after `PreprocessTableCreation` and
* `PreprocessTableInsertion`.
*/
-case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] {
+case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport {
def resolver: Resolver = conf.resolver
@@ -98,11 +98,11 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] {
val potentialSpecs = staticPartitions.filter {
case (partKey, partValue) => resolver(field.name, partKey)
}
- if (potentialSpecs.size == 0) {
+ if (potentialSpecs.isEmpty) {
None
} else if (potentialSpecs.size == 1) {
val partValue = potentialSpecs.head._2
- Some(Alias(Cast(Literal(partValue), field.dataType), field.name)())
+ Some(Alias(cast(Literal(partValue), field.dataType), field.name)())
} else {
throw new AnalysisException(
s"Partition column ${field.name} have multiple values specified, " +
@@ -258,7 +258,9 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
/**
* A Strategy for planning scans over data sources defined using the sources API.
*/
-object DataSourceStrategy extends Strategy with Logging {
+case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with CastSupport {
+ import DataSourceStrategy._
+
def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match {
case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _)) =>
pruneFilterProjectRaw(
@@ -298,7 +300,7 @@ object DataSourceStrategy extends Strategy with Logging {
// Restriction: Bucket pruning works iff the bucketing column has one and only one column.
def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = {
val mutableRow = new SpecificInternalRow(Seq(bucketColumn.dataType))
- mutableRow(0) = Cast(Literal(value), bucketColumn.dataType).eval(null)
+ mutableRow(0) = cast(Literal(value), bucketColumn.dataType).eval(null)
val bucketIdGeneration = UnsafeProjection.create(
HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil,
bucketColumn :: Nil)
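For context, `getBucketId` evaluates the cast literal eagerly and then reuses `HashPartitioning.partitionIdExpression`, the same hash-and-modulo expression used when writing buckets, so pruning and writing cannot disagree. Conceptually the computation reduces to a non-negative modulo of a hash; a sketch with a stand-in hash function (Spark actually uses a Murmur3-based hash here):

```scala
// Conceptual sketch of bucket-id computation: hash the value, then take
// a non-negative modulo over the bucket count. hashCode is a stand-in
// for Spark's Murmur3-based hash.
def bucketIdSketch(value: Any, numBuckets: Int): Int = {
  val h = value.hashCode()
  // Pmod: keep the result in [0, numBuckets) even for negative hashes.
  ((h % numBuckets) + numBuckets) % numBuckets
}
```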
@@ -436,7 +438,9 @@ object DataSourceStrategy extends Strategy with Logging {
private[this] def toCatalystRDD(relation: LogicalRelation, rdd: RDD[Row]): RDD[InternalRow] = {
toCatalystRDD(relation, relation.output, rdd)
}
+}
+object DataSourceStrategy {
/**
* Tries to translate a Catalyst [[Expression]] into data source [[Filter]].
*
@@ -527,8 +531,8 @@ object DataSourceStrategy extends Strategy with Logging {
* all [[Filter]]s that are completely filtered at the DataSource.
*/
protected[sql] def selectFilters(
- relation: BaseRelation,
- predicates: Seq[Expression]): (Seq[Expression], Seq[Filter], Set[Filter]) = {
+ relation: BaseRelation,
+ predicates: Seq[Expression]): (Seq[Expression], Seq[Filter], Set[Filter]) = {
// For conciseness, all Catalyst filter expressions of type `expressions.Expression` below are
// called `predicate`s, while all data source filters of type `sources.Filter` are simply called
// `filter`s.
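Finally, note the shape of the refactoring: the old `object DataSourceStrategy` becomes a conf-carrying `case class` (so it can satisfy `CastSupport`'s `def conf`), while stateless helpers such as `translateFilter` and `selectFilters` move to a companion object, with `import DataSourceStrategy._` keeping call sites inside the class unchanged. A minimal sketch of that pattern, with purely illustrative names:

```scala
// Companion-object split: per-session state lives on the case class,
// stateless helpers on the companion object.
case class Strategy(timeZone: String) {
  import Strategy._ // unqualified access to companion helpers

  // Instance method: depends on per-session state.
  def describe(value: Int): String = s"${normalize(value)} (tz = $timeZone)"
}

object Strategy {
  // Stateless helper: callable without an instance, e.g. Strategy.normalize(-3).
  def normalize(value: Int): Int = math.abs(value)
}
```

Here `Strategy("UTC").describe(-3)` yields `"3 (tz = UTC)"`.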