 sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala
index b16c9f8fc9..735e07c213 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, Expression, Literal}
import org.apache.spark.sql.execution.datasources.DataSourceAnalysis
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.types.{DataType, IntegerType, StructType}

class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll {

@@ -49,7 +49,11 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll {
}

Seq(true, false).foreach { caseSensitive =>
- val rule = DataSourceAnalysis(new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive))
+ val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive)
+ def cast(e: Expression, dt: DataType): Expression = {
+   Cast(e, dt, Option(conf.sessionLocalTimeZone))
+ }
+ val rule = DataSourceAnalysis(conf)

test(
s"convertStaticPartitions only handle INSERT having at least static partitions " +
s"(caseSensitive: $caseSensitive)") {
@@ -150,7 +154,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll {
if (!caseSensitive) {
val nonPartitionedAttributes = Seq('e.int, 'f.int)
val expected = nonPartitionedAttributes ++
- Seq(Cast(Literal("1"), IntegerType), Cast(Literal("3"), IntegerType))
+ Seq(cast(Literal("1"), IntegerType), cast(Literal("3"), IntegerType))
val actual = rule.convertStaticPartitions(
sourceAttributes = nonPartitionedAttributes,
providedPartitions = Map("b" -> Some("1"), "C" -> Some("3")),
@@ -162,7 +166,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll {
{
val nonPartitionedAttributes = Seq('e.int, 'f.int)
val expected = nonPartitionedAttributes ++
- Seq(Cast(Literal("1"), IntegerType), Cast(Literal("3"), IntegerType))
+ Seq(cast(Literal("1"), IntegerType), cast(Literal("3"), IntegerType))
val actual = rule.convertStaticPartitions(
sourceAttributes = nonPartitionedAttributes,
providedPartitions = Map("b" -> Some("1"), "c" -> Some("3")),
@@ -174,7 +178,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll {
// Test the case having a single static partition column.
{
val nonPartitionedAttributes = Seq('e.int, 'f.int)
- val expected = nonPartitionedAttributes ++ Seq(Cast(Literal("1"), IntegerType))
+ val expected = nonPartitionedAttributes ++ Seq(cast(Literal("1"), IntegerType))
val actual = rule.convertStaticPartitions(
sourceAttributes = nonPartitionedAttributes,
providedPartitions = Map("b" -> Some("1")),
@@ -189,7 +193,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll {
val dynamicPartitionAttributes = Seq('g.int)
val expected =
nonPartitionedAttributes ++
- Seq(Cast(Literal("1"), IntegerType)) ++
+ Seq(cast(Literal("1"), IntegerType)) ++
dynamicPartitionAttributes
val actual = rule.convertStaticPartitions(
sourceAttributes = nonPartitionedAttributes ++ dynamicPartitionAttributes,
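
For reference, the pattern this patch adopts is the timezone-aware form of Catalyst's Cast: instead of the two-argument Cast(child, dataType), the expected expressions are built through a small helper that threads the session-local time zone from SQLConf as Cast's optional third argument. Below is a minimal standalone sketch of that construction, assuming a spark-catalyst dependency from this era on the classpath (the object name is hypothetical):

import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.IntegerType

object TimeZoneAwareCastSketch {
  def main(args: Array[String]): Unit = {
    // Same construction the updated tests use: copy the conf, then pass
    // its session-local time zone to Cast as the optional third argument.
    val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> false)
    val expected = Cast(Literal("1"), IntegerType, Option(conf.sessionLocalTimeZone))
    println(expected) // e.g. cast(1 as int)
  }
}

Passing the time zone explicitly keeps comparisons against analyzer output stable: Cast is a case class, so two Cast expressions are equal only when their timeZoneId fields match as well, which is why the bare Cast(Literal(...), IntegerType) expectations had to change.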