diff options
Diffstat (limited to 'sql/core/src/main/scala')
4 files changed, 17 insertions, 13 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 6566502bd8..4e718d609c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -36,7 +36,7 @@ class SparkPlanner( experimentalMethods.extraStrategies ++ extraPlanningStrategies ++ ( FileSourceStrategy :: - DataSourceStrategy :: + DataSourceStrategy(conf) :: SpecialLimits :: Aggregation :: JoinSelection :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 2d83d512e7..d307122b5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -24,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QualifiedTableName, TableIdentifier} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QualifiedTableName} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogUtils} @@ -48,7 +48,7 @@ import org.apache.spark.unsafe.types.UTF8String * Note that, this rule must be run after `PreprocessTableCreation` and * `PreprocessTableInsertion`. */ -case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] { +case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { def resolver: Resolver = conf.resolver @@ -98,11 +98,11 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] { val potentialSpecs = staticPartitions.filter { case (partKey, partValue) => resolver(field.name, partKey) } - if (potentialSpecs.size == 0) { + if (potentialSpecs.isEmpty) { None } else if (potentialSpecs.size == 1) { val partValue = potentialSpecs.head._2 - Some(Alias(Cast(Literal(partValue), field.dataType), field.name)()) + Some(Alias(cast(Literal(partValue), field.dataType), field.name)()) } else { throw new AnalysisException( s"Partition column ${field.name} have multiple values specified, " + @@ -258,7 +258,9 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] /** * A Strategy for planning scans over data sources defined using the sources API. */ -object DataSourceStrategy extends Strategy with Logging { +case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with CastSupport { + import DataSourceStrategy._ + def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match { case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _)) => pruneFilterProjectRaw( @@ -298,7 +300,7 @@ object DataSourceStrategy extends Strategy with Logging { // Restriction: Bucket pruning works iff the bucketing column has one and only one column. def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = { val mutableRow = new SpecificInternalRow(Seq(bucketColumn.dataType)) - mutableRow(0) = Cast(Literal(value), bucketColumn.dataType).eval(null) + mutableRow(0) = cast(Literal(value), bucketColumn.dataType).eval(null) val bucketIdGeneration = UnsafeProjection.create( HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil, bucketColumn :: Nil) @@ -436,7 +438,9 @@ object DataSourceStrategy extends Strategy with Logging { private[this] def toCatalystRDD(relation: LogicalRelation, rdd: RDD[Row]): RDD[InternalRow] = { toCatalystRDD(relation, relation.output, rdd) } +} +object DataSourceStrategy { /** * Tries to translate a Catalyst [[Expression]] into data source [[Filter]]. * @@ -527,8 +531,8 @@ object DataSourceStrategy extends Strategy with Logging { * all [[Filter]]s that are completely filtered at the DataSource. */ protected[sql] def selectFilters( - relation: BaseRelation, - predicates: Seq[Expression]): (Seq[Expression], Seq[Filter], Set[Filter]) = { + relation: BaseRelation, + predicates: Seq[Expression]): (Seq[Expression], Seq[Filter], Set[Filter]) = { // For conciseness, all Catalyst filter expressions of type `expressions.Expression` below are // called `predicate`s, while all data source filters of type `sources.Filter` are simply called diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 7abf2ae516..3f4a78580f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -22,7 +22,7 @@ import java.util.Locale import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, RowOrdering} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.command.DDLUtils @@ -315,7 +315,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi * table. It also does data type casting and field renaming, to make sure that the columns to be * inserted have the correct data type and fields have the correct names. */ -case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { +case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { private def preprocess( insert: InsertIntoTable, tblName: String, @@ -367,7 +367,7 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { // Renaming is needed for handling the following cases like // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2 // 2) Target tables have column metadata - Alias(Cast(actual, expected.dataType), expected.name)( + Alias(cast(actual, expected.dataType), expected.name)( explicitMetadata = Option(expected.metadata)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 2b14eca919..df7c3678b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.internal import org.apache.spark.SparkConf import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.{ExperimentalMethods, SparkSession, Strategy, UDFRegistration} -import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry, ResolveTimeZone} import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface |