Diffstat (limited to 'sql/core/src/main/scala')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala                   |  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala | 20
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala              |  6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala         |  2
4 files changed, 17 insertions(+), 13 deletions(-)
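
In brief: this commit threads the session's SQLConf through DataSourceStrategy and the affected analysis rules so that every Cast they build carries the session time zone (spark.sql.session.timeZone), rather than being constructed without one. The rules gain this through the catalyst CastSupport mixin. A minimal sketch of what that mixin provides, inferred from how `cast` is used in the hunks below (the real trait lives in catalyst and is not part of this diff):

    import org.apache.spark.sql.catalyst.expressions.{Cast, Expression}
    import org.apache.spark.sql.internal.SQLConf
    import org.apache.spark.sql.types.DataType

    // Sketch only, not the verbatim trait: the mixing-in rule or strategy
    // exposes its SQLConf, and `cast` pins the session time zone onto the
    // Cast at construction time.
    trait CastSupportSketch {
      def conf: SQLConf

      def cast(child: Expression, dataType: DataType): Cast =
        Cast(child, dataType, Option(conf.sessionLocalTimeZone))
    }
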
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index 6566502bd8..4e718d609c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -36,7 +36,7 @@ class SparkPlanner(
experimentalMethods.extraStrategies ++
extraPlanningStrategies ++ (
FileSourceStrategy ::
- DataSourceStrategy ::
+ DataSourceStrategy(conf) ::
SpecialLimits ::
Aggregation ::
JoinSelection ::
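
Since DataSourceStrategy is now a case class over SQLConf, each SparkPlanner instantiates it from its own session's conf, so planning-time casts (e.g. in the bucket pruning further down) see that session's time zone. Illustrative only, since sessionState is Spark-internal:

    // Two sessions with different time zone settings plan with distinct
    // strategy instances, each carrying its own conf.
    val s1 = DataSourceStrategy(spark1.sessionState.conf)
    val s2 = DataSourceStrategy(spark2.sessionState.conf)
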
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 2d83d512e7..d307122b5c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QualifiedTableName, TableIdentifier}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QualifiedTableName}
import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogUtils}
@@ -48,7 +48,7 @@ import org.apache.spark.unsafe.types.UTF8String
* Note that, this rule must be run after `PreprocessTableCreation` and
* `PreprocessTableInsertion`.
*/
-case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] {
+case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport {
def resolver: Resolver = conf.resolver
@@ -98,11 +98,11 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] {
val potentialSpecs = staticPartitions.filter {
case (partKey, partValue) => resolver(field.name, partKey)
}
- if (potentialSpecs.size == 0) {
+ if (potentialSpecs.isEmpty) {
None
} else if (potentialSpecs.size == 1) {
val partValue = potentialSpecs.head._2
- Some(Alias(Cast(Literal(partValue), field.dataType), field.name)())
+ Some(Alias(cast(Literal(partValue), field.dataType), field.name)())
} else {
throw new AnalysisException(
s"Partition column ${field.name} have multiple values specified, " +
@@ -258,7 +258,9 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
/**
* A Strategy for planning scans over data sources defined using the sources API.
*/
-object DataSourceStrategy extends Strategy with Logging {
+case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with CastSupport {
+ import DataSourceStrategy._
+
def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match {
case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _)) =>
pruneFilterProjectRaw(
@@ -298,7 +300,7 @@ object DataSourceStrategy extends Strategy with Logging {
// Restriction: Bucket pruning works iff the bucketing column has one and only one column.
def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = {
val mutableRow = new SpecificInternalRow(Seq(bucketColumn.dataType))
- mutableRow(0) = Cast(Literal(value), bucketColumn.dataType).eval(null)
+ mutableRow(0) = cast(Literal(value), bucketColumn.dataType).eval(null)
val bucketIdGeneration = UnsafeProjection.create(
HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil,
bucketColumn :: Nil)
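
getBucketId is what lets bucket pruning resolve an equality filter to a single bucket: the literal is cast (now time-zone-aware) to the bucket column's type, evaluated, and fed through the same hash-based partition-id expression used at write time. Hypothetical usage, assuming a DataSourceStrategy instance named `strategy`:

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.types.IntegerType

    // For a table bucketed by `k` into 8 buckets, a filter `k = 5` only
    // requires scanning the files of one bucket.
    val k = AttributeReference("k", IntegerType)()
    val bucketId = strategy.getBucketId(k, numBuckets = 8, value = 5) // in [0, 8)
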
@@ -436,7 +438,9 @@ object DataSourceStrategy extends Strategy with Logging {
private[this] def toCatalystRDD(relation: LogicalRelation, rdd: RDD[Row]): RDD[InternalRow] = {
toCatalystRDD(relation, relation.output, rdd)
}
+}
+object DataSourceStrategy {
/**
* Tries to translate a Catalyst [[Expression]] into data source [[Filter]].
*
@@ -527,8 +531,8 @@ object DataSourceStrategy extends Strategy with Logging {
* all [[Filter]]s that are completely filtered at the DataSource.
*/
protected[sql] def selectFilters(
- relation: BaseRelation,
- predicates: Seq[Expression]): (Seq[Expression], Seq[Filter], Set[Filter]) = {
+ relation: BaseRelation,
+ predicates: Seq[Expression]): (Seq[Expression], Seq[Filter], Set[Filter]) = {
// For conciseness, all Catalyst filter expressions of type `expressions.Expression` below are
// called `predicate`s, while all data source filters of type `sources.Filter` are simply called
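
The object split at the end of the file is the flip side of the conf change: helpers with no conf dependency, such as translateFilter and selectFilters, move to the new companion object, which the case class pulls back in via `import DataSourceStrategy._`. Filter translation thus remains a pure function from Catalyst predicates to source Filters; a sketch (translateFilter is package-private, so this is for illustration, and the exact Some/None outcomes are an assumption about these inputs):

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
    import org.apache.spark.sql.types.IntegerType

    val a = AttributeReference("a", IntegerType)()

    // A predicate with a direct source-API counterpart translates;
    // anything that does not translate returns None and is evaluated
    // by Spark after the scan.
    DataSourceStrategy.translateFilter(EqualTo(a, Literal(1)))
    // expected: Some(sources.EqualTo("a", 1))
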
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 7abf2ae516..3f4a78580f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -22,7 +22,7 @@ import java.util.Locale
import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, RowOrdering}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.command.DDLUtils
@@ -315,7 +315,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi
* table. It also does data type casting and field renaming, to make sure that the columns to be
* inserted have the correct data type and fields have the correct names.
*/
-case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] {
+case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport {
private def preprocess(
insert: InsertIntoTable,
tblName: String,
@@ -367,7 +367,7 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] {
// Renaming is needed for handling the following cases like
// 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2
// 2) Target tables have column metadata
- Alias(Cast(actual, expected.dataType), expected.name)(
+ Alias(cast(actual, expected.dataType), expected.name)(
explicitMetadata = Option(expected.metadata))
}
}
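
For PreprocessTableInsertion the effect is user-visible: the implicit casts added to reconcile inserted columns with the target schema now honor the session time zone. A hypothetical end-to-end illustration:

    // String-to-timestamp coercion during INSERT uses the session zone.
    spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")
    spark.sql("CREATE TABLE t (ts TIMESTAMP) USING parquet")
    spark.sql("INSERT INTO t SELECT '2017-01-01 00:00:00'")
    // The literal is read as midnight in Los Angeles, not in the JVM
    // default zone.
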
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 2b14eca919..df7c3678b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.internal
import org.apache.spark.SparkConf
import org.apache.spark.annotation.{Experimental, InterfaceStability}
import org.apache.spark.sql.{ExperimentalMethods, SparkSession, Strategy, UDFRegistration}
-import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry, ResolveTimeZone}
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.parser.ParserInterface
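
The new ResolveTimeZone import implies the builder registers that rule somewhere below this hunk (the registration site is not shown here). ResolveTimeZone is the catch-all counterpart to CastSupport: it walks a plan and stamps the session time zone onto any time-zone-aware expression that still lacks one. A sketch, assuming ResolveTimeZone(conf) constructs the rule as its use alongside SQLConf suggests, with a SQLConf `conf` and a LogicalPlan `plan` in scope:

    import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone

    // Fill in the session time zone on any TimeZoneAwareExpression
    // (e.g. a bare Cast) created without one.
    val withTimeZones = ResolveTimeZone(conf).apply(plan)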