From 2a7921a813ecd847fd933ffef10edc64684e9df7 Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Fri, 3 Mar 2017 16:35:54 -0800
Subject: [SPARK-18939][SQL] Timezone support in partition values.

## What changes were proposed in this pull request?

This is a follow-up PR of #16308 and #16750.

This PR enables timezone support in partition values. It uses the `timeZone` option introduced in #16750 to parse and format partition values of `TimestampType`.

For example, if you have timestamp `"2016-01-01 00:00:00"` in `GMT` which will be used for partition values, the values written with the default timezone option (which is `"GMT"` here because that is the session local timezone) are:

```scala
scala> spark.conf.set("spark.sql.session.timeZone", "GMT")

scala> val df = Seq((1, new java.sql.Timestamp(1451606400000L))).toDF("i", "ts")
df: org.apache.spark.sql.DataFrame = [i: int, ts: timestamp]

scala> df.show()
+---+-------------------+
|  i|                 ts|
+---+-------------------+
|  1|2016-01-01 00:00:00|
+---+-------------------+

scala> df.write.partitionBy("ts").save("/path/to/gmtpartition")
```

```sh
$ ls /path/to/gmtpartition/
_SUCCESS  ts=2016-01-01 00%3A00%3A00
```

whereas setting the option to `"PST"`, they are:

```scala
scala> df.write.option("timeZone", "PST").partitionBy("ts").save("/path/to/pstpartition")
```

```sh
$ ls /path/to/pstpartition/
_SUCCESS  ts=2015-12-31 16%3A00%3A00
```

We can properly read the partition values back if the session local timezone and the timezone of the partition values are the same:

```scala
scala> spark.read.load("/path/to/gmtpartition").show()
+---+-------------------+
|  i|                 ts|
+---+-------------------+
|  1|2016-01-01 00:00:00|
+---+-------------------+
```

Even if the timezones differ, we can still read the values correctly by setting the correct `timeZone` option:

```scala
// wrong result
scala> spark.read.load("/path/to/pstpartition").show()
+---+-------------------+
|  i|                 ts|
+---+-------------------+
|  1|2015-12-31 16:00:00|
+---+-------------------+

// correct result
scala> spark.read.option("timeZone", "PST").load("/path/to/pstpartition").show()
+---+-------------------+
|  i|                 ts|
+---+-------------------+
|  1|2016-01-01 00:00:00|
+---+-------------------+
```

## How was this patch tested?

Existing tests and added some tests.

Author: Takuya UESHIN

Closes #17053 from ueshin/issues/SPARK-18939.
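For reference, the directory names in the listings above follow directly from rendering the same instant in the two zones. A minimal, self-contained sketch using `java.time` (only an illustration, not the code path touched by this patch; `America/Los_Angeles` stands in for the `"PST"` option):

```scala
import java.time.{Instant, ZoneId}
import java.time.format.DateTimeFormatter

// 1451606400000L is the epoch millis used in the example above
// (2016-01-01 00:00:00 in GMT).
val instant = Instant.ofEpochMilli(1451606400000L)
val fmt = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")

fmt.format(instant.atZone(ZoneId.of("GMT")))                 // 2016-01-01 00:00:00
fmt.format(instant.atZone(ZoneId.of("America/Los_Angeles"))) // 2015-12-31 16:00:00
```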
---
 .../sql/execution/OptimizeMetadataOnlyQuery.scala | 10 ++--
 .../execution/datasources/CatalogFileIndex.scala  |  3 +-
 .../execution/datasources/FileFormatWriter.scala  | 18 ++++---
 .../datasources/PartitioningAwareFileIndex.scala  | 16 ++++--
 .../execution/datasources/PartitioningUtils.scala | 42 +++++++++++----
 .../sql/execution/datasources/csv/CSVSuite.scala  | 15 ++++--
 .../parquet/ParquetPartitionDiscoverySuite.scala  | 62 +++++++++++++++++-----
 .../spark/sql/sources/PartitionedWriteSuite.scala | 35 ++++++++++++
 8 files changed, 155 insertions(+), 46 deletions(-)

(limited to 'sql/core/src')

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
index b02edd4c74..aa578f4d23 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
 import org.apache.spark.sql.internal.SQLConf
 
@@ -103,11 +103,13 @@ case class OptimizeMetadataOnlyQuery(
       case relation: CatalogRelation =>
         val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation)
+        val caseInsensitiveProperties =
+          CaseInsensitiveMap(relation.tableMeta.storage.properties)
+        val timeZoneId = caseInsensitiveProperties.get("timeZone")
+          .getOrElse(conf.sessionLocalTimeZone)
         val partitionData = catalog.listPartitions(relation.tableMeta.identifier).map { p =>
           InternalRow.fromSeq(partAttrs.map { attr =>
-            // TODO: use correct timezone for partition values.
-            Cast(Literal(p.spec(attr.name)), attr.dataType,
-              Option(DateTimeUtils.defaultTimeZone().getID)).eval()
+            Cast(Literal(p.spec(attr.name)), attr.dataType, Option(timeZoneId)).eval()
           })
         }
         LocalRelation(partAttrs, partitionData)
 
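For metadata-only queries the timezone now comes from the table's `timeZone` storage property when present (matched case-insensitively) and otherwise from the session-local timezone. A rough sketch of that lookup, with hypothetical parameter names standing in for the catalog storage properties and the SQLConf value:

```scala
// Sketch only: Spark's CaseInsensitiveMap makes "timezone" / "timeZone" /
// "TIMEZONE" all resolve to the same entry; a lower-cased copy approximates
// that behavior here.
def resolveTimeZoneId(
    storageProperties: Map[String, String],
    sessionLocalTimeZone: String): String = {
  val lowerCased = storageProperties.map { case (k, v) => (k.toLowerCase, v) }
  lowerCased.getOrElse("timezone", sessionLocalTimeZone)
}

resolveTimeZoneId(Map("timeZone" -> "PST"), "GMT")  // "PST"
resolveTimeZoneId(Map.empty, "GMT")                 // "GMT"
```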
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala
index 1235a4b12f..2068811661 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala
@@ -72,7 +72,8 @@ class CatalogFileIndex(
         val path = new Path(p.location)
         val fs = path.getFileSystem(hadoopConf)
         PartitionPath(
-          p.toRow(partitionSchema), path.makeQualified(fs.getUri, fs.getWorkingDirectory))
+          p.toRow(partitionSchema, sparkSession.sessionState.conf.sessionLocalTimeZone),
+          path.makeQualified(fs.getUri, fs.getWorkingDirectory))
       }
       val partitionSpec = PartitionSpec(partitionSchema, partitions)
       new PrunedInMemoryFileIndex(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index c17796811c..950e5ca0d6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution}
 import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -68,7 +68,8 @@ object FileFormatWriter extends Logging {
       val bucketIdExpression: Option[Expression],
       val path: String,
       val customPartitionLocations: Map[TablePartitionSpec, String],
-      val maxRecordsPerFile: Long)
+      val maxRecordsPerFile: Long,
+      val timeZoneId: String)
     extends Serializable {
 
     assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns),
@@ -122,9 +123,11 @@ object FileFormatWriter extends Logging {
       spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
     }
 
+    val caseInsensitiveOptions = CaseInsensitiveMap(options)
+
     // Note: prepareWrite has side effect. It sets "job".
     val outputWriterFactory =
-      fileFormat.prepareWrite(sparkSession, job, options, dataColumns.toStructType)
+      fileFormat.prepareWrite(sparkSession, job, caseInsensitiveOptions, dataColumns.toStructType)
 
     val description = new WriteJobDescription(
       uuid = UUID.randomUUID().toString,
@@ -136,8 +139,10 @@
       bucketIdExpression = bucketIdExpression,
       path = outputSpec.outputPath,
       customPartitionLocations = outputSpec.customPartitionLocations,
-      maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong)
-        .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile)
+      maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong)
+        .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile),
+      timeZoneId = caseInsensitiveOptions.get("timeZone")
+        .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone)
     )
 
     // We should first sort by partition columns, then bucket id, and finally sorting columns.
@@ -330,11 +335,10 @@ object FileFormatWriter extends Logging {
     /** Expressions that given partition columns build a path string like: col1=val/col2=val/... */
     private def partitionPathExpression: Seq[Expression] = {
       desc.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
-        // TODO: use correct timezone for partition values.
         val escaped = ScalaUDF(
           ExternalCatalogUtils.escapePathName _,
           StringType,
-          Seq(Cast(c, StringType, Option(DateTimeUtils.defaultTimeZone().getID))),
+          Seq(Cast(c, StringType, Option(desc.timeZoneId))),
           Seq(StringType))
         val str = If(IsNull(c), Literal(ExternalCatalogUtils.DEFAULT_PARTITION_NAME), escaped)
         val partitionName = Literal(c.name + "=") :: str :: Nil
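`partitionPathExpression` now renders each partition value in the job's `timeZoneId` before escaping it into a `col=value` path segment. A standalone approximation of what that produces for a timestamp column (an illustration only, not the Catalyst expression tree built above; only ':' is escaped here, whereas `ExternalCatalogUtils.escapePathName` handles the full set of unsafe characters):

```scala
import java.text.SimpleDateFormat
import java.util.{Date, TimeZone}

def partitionDirName(colName: String, epochMillis: Long, timeZoneId: String): String = {
  // Format the instant in the writer's timezone, then escape ':' the way the
  // real path escaping does ("%3A").
  val fmt = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
  fmt.setTimeZone(TimeZone.getTimeZone(timeZoneId))
  val escaped = fmt.format(new Date(epochMillis)).replace(":", "%3A")
  s"$colName=$escaped"
}

partitionDirName("ts", 1451606400000L, "GMT")  // "ts=2016-01-01 00%3A00%3A00"
partitionDirName("ts", 1451606400000L, "PST")  // "ts=2015-12-31 16%3A00%3A00"
```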
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
index 549257c0e1..c8097a7fab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
@@ -30,7 +30,7 @@ import org.apache.spark.metrics.source.HiveCatalogMetrics
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.{expressions, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.util.SerializableConfiguration
 
@@ -125,22 +125,27 @@ abstract class PartitioningAwareFileIndex(
     val leafDirs = leafDirToChildrenFiles.filter { case (_, files) =>
       files.exists(f => isDataPath(f.getPath))
     }.keys.toSeq
+
+    val caseInsensitiveOptions = CaseInsensitiveMap(parameters)
+    val timeZoneId = caseInsensitiveOptions.get("timeZone")
+      .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone)
+
     userPartitionSchema match {
       case Some(userProvidedSchema) if userProvidedSchema.nonEmpty =>
         val spec = PartitioningUtils.parsePartitions(
           leafDirs,
           typeInference = false,
-          basePaths = basePaths)
+          basePaths = basePaths,
+          timeZoneId = timeZoneId)
 
         // Without auto inference, all of value in the `row` should be null or in StringType,
         // we need to cast into the data type that user specified.
        def castPartitionValuesToUserSchema(row: InternalRow) = {
           InternalRow((0 until row.numFields).map { i =>
-            // TODO: use correct timezone for partition values.
             Cast(
               Literal.create(row.getUTF8String(i), StringType),
               userProvidedSchema.fields(i).dataType,
-              Option(DateTimeUtils.defaultTimeZone().getID)).eval()
+              Option(timeZoneId)).eval()
           }: _*)
         }
 
@@ -151,7 +156,8 @@ abstract class PartitioningAwareFileIndex(
         PartitioningUtils.parsePartitions(
           leafDirs,
           typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled,
-          basePaths = basePaths)
+          basePaths = basePaths,
+          timeZoneId = timeZoneId)
     }
   }
 
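On the read side, when the user supplies a schema, discovered partition values are kept as strings and then cast to the user's types with the resolved timezone; the `timeZone` read option (looked up case-insensitively) takes precedence over the session-local timezone. A short usage sketch based on the paths from the description above:

```scala
import org.apache.spark.sql.types._

val userSchema = StructType(Seq(
  StructField("i", IntegerType),
  StructField("ts", TimestampType)))

// The string partition value "2015-12-31 16:00:00" is cast to a timestamp
// using the "PST" option rather than the session-local timezone.
val df = spark.read
  .schema(userSchema)
  .option("timeZone", "PST")
  .load("/path/to/pstpartition")
```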
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
index bad59961ac..09876bbc2f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources
 
 import java.lang.{Double => JDouble, Long => JLong}
 import java.math.{BigDecimal => JBigDecimal}
-import java.sql.{Date => JDate, Timestamp => JTimestamp}
+import java.util.TimeZone
 
 import scala.collection.mutable.ArrayBuffer
 import scala.util.Try
@@ -31,7 +31,9 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 // TODO: We should tighten up visibility of the classes here once we clean up Hive coupling.
 
@@ -91,10 +93,19 @@ object PartitioningUtils {
   private[datasources] def parsePartitions(
       paths: Seq[Path],
       typeInference: Boolean,
-      basePaths: Set[Path]): PartitionSpec = {
+      basePaths: Set[Path],
+      timeZoneId: String): PartitionSpec = {
+    parsePartitions(paths, typeInference, basePaths, TimeZone.getTimeZone(timeZoneId))
+  }
+
+  private[datasources] def parsePartitions(
+      paths: Seq[Path],
+      typeInference: Boolean,
+      basePaths: Set[Path],
+      timeZone: TimeZone): PartitionSpec = {
     // First, we need to parse every partition's path and see if we can find partition values.
     val (partitionValues, optDiscoveredBasePaths) = paths.map { path =>
-      parsePartition(path, typeInference, basePaths)
+      parsePartition(path, typeInference, basePaths, timeZone)
     }.unzip
 
     // We create pairs of (path -> path's partition value) here
@@ -173,7 +184,8 @@ object PartitioningUtils {
   private[datasources] def parsePartition(
       path: Path,
       typeInference: Boolean,
-      basePaths: Set[Path]): (Option[PartitionValues], Option[Path]) = {
+      basePaths: Set[Path],
+      timeZone: TimeZone): (Option[PartitionValues], Option[Path]) = {
     val columns = ArrayBuffer.empty[(String, Literal)]
     // Old Hadoop versions don't have `Path.isRoot`
     var finished = path.getParent == null
@@ -194,7 +206,7 @@ object PartitioningUtils {
       // Let's say currentPath is a path of "/table/a=1/", currentPath.getName will give us a=1.
       // Once we get the string, we try to parse it and find the partition column and value.
       val maybeColumn =
-        parsePartitionColumn(currentPath.getName, typeInference)
+        parsePartitionColumn(currentPath.getName, typeInference, timeZone)
       maybeColumn.foreach(columns += _)
 
       // Now, we determine if we should stop.
@@ -226,7 +238,8 @@ object PartitioningUtils {
 
   private def parsePartitionColumn(
       columnSpec: String,
-      typeInference: Boolean): Option[(String, Literal)] = {
+      typeInference: Boolean,
+      timeZone: TimeZone): Option[(String, Literal)] = {
     val equalSignIndex = columnSpec.indexOf('=')
     if (equalSignIndex == -1) {
       None
@@ -237,7 +250,7 @@
       val rawColumnValue = columnSpec.drop(equalSignIndex + 1)
       assert(rawColumnValue.nonEmpty, s"Empty partition column value in '$columnSpec'")
 
-      val literal = inferPartitionColumnValue(rawColumnValue, typeInference)
+      val literal = inferPartitionColumnValue(rawColumnValue, typeInference, timeZone)
       Some(columnName -> literal)
     }
   }
@@ -370,7 +383,8 @@ object PartitioningUtils {
    */
   private[datasources] def inferPartitionColumnValue(
       raw: String,
-      typeInference: Boolean): Literal = {
+      typeInference: Boolean,
+      timeZone: TimeZone): Literal = {
     val decimalTry = Try {
       // `BigDecimal` conversion can fail when the `field` is not a form of number.
       val bigDecimal = new JBigDecimal(raw)
@@ -390,8 +404,16 @@
         // Then falls back to fractional types
         .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType)))
         // Then falls back to date/timestamp types
-        .orElse(Try(Literal(JDate.valueOf(raw))))
-        .orElse(Try(Literal(JTimestamp.valueOf(unescapePathName(raw)))))
+        .orElse(Try(
+          Literal.create(
+            DateTimeUtils.getThreadLocalTimestampFormat(timeZone)
+              .parse(unescapePathName(raw)).getTime * 1000L,
+            TimestampType)))
+        .orElse(Try(
+          Literal.create(
+            DateTimeUtils.millisToDays(
+              DateTimeUtils.getThreadLocalDateFormat.parse(raw).getTime),
+            DateType)))
         // Then falls back to string
         .getOrElse {
           if (raw == DEFAULT_PARTITION_NAME) {
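Partition value inference now tries a timezone-aware timestamp parse before falling back to string, converting the parsed epoch millis to the microsecond value Catalyst uses for `TimestampType` (the `* 1000L` above). A standalone sketch of that branch; `DateTimeUtils.getThreadLocalTimestampFormat` is Spark-internal, so a plain `SimpleDateFormat` stands in for it here:

```scala
import java.text.SimpleDateFormat
import java.util.TimeZone
import scala.util.Try

def inferTimestampMicros(rawUnescaped: String, timeZone: TimeZone): Option[Long] = {
  // Parse the unescaped partition value in the given zone, then convert epoch
  // millis to microseconds, mirroring the TimestampType branch above.
  val fmt = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
  fmt.setTimeZone(timeZone)
  Try(fmt.parse(rawUnescaped).getTime * 1000L).toOption
}

inferTimestampMicros("2016-01-01 00:00:00", TimeZone.getTimeZone("GMT"))
// Some(1451606400000000L)
```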
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index d94eb66201..56071803f6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -742,10 +742,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
         .save(iso8601timestampsPath)
 
       // This will load back the timestamps as string.
+      val stringSchema = StructType(StructField("date", StringType, true) :: Nil)
       val iso8601Timestamps = spark.read
         .format("csv")
+        .schema(stringSchema)
         .option("header", "true")
-        .option("inferSchema", "false")
         .load(iso8601timestampsPath)
 
       val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSZZ", Locale.US)
@@ -775,10 +776,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
         .save(iso8601datesPath)
 
       // This will load back the dates as string.
+      val stringSchema = StructType(StructField("date", StringType, true) :: Nil)
       val iso8601dates = spark.read
         .format("csv")
+        .schema(stringSchema)
         .option("header", "true")
-        .option("inferSchema", "false")
         .load(iso8601datesPath)
 
       val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd", Locale.US)
@@ -833,10 +835,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
         .save(datesWithFormatPath)
 
       // This will load back the dates as string.
+      val stringSchema = StructType(StructField("date", StringType, true) :: Nil)
       val stringDatesWithFormat = spark.read
         .format("csv")
+        .schema(stringSchema)
         .option("header", "true")
-        .option("inferSchema", "false")
         .load(datesWithFormatPath)
       val expectedStringDatesWithFormat = Seq(
         Row("2015/08/26"),
@@ -864,10 +867,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
         .save(timestampsWithFormatPath)
 
       // This will load back the timestamps as string.
+      val stringSchema = StructType(StructField("date", StringType, true) :: Nil)
       val stringTimestampsWithFormat = spark.read
         .format("csv")
+        .schema(stringSchema)
         .option("header", "true")
-        .option("inferSchema", "false")
         .load(timestampsWithFormatPath)
       val expectedStringTimestampsWithFormat = Seq(
         Row("2015/08/26 18:00"),
@@ -896,10 +900,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
         .save(timestampsWithFormatPath)
 
       // This will load back the timestamps as string.
+      val stringSchema = StructType(StructField("date", StringType, true) :: Nil)
       val stringTimestampsWithFormat = spark.read
         .format("csv")
+        .schema(stringSchema)
         .option("header", "true")
-        .option("inferSchema", "false")
         .load(timestampsWithFormatPath)
       val expectedStringTimestampsWithFormat = Seq(
         Row("2015/08/27 01:00"),
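The CSV test changes above swap `inferSchema=false` for an explicit single-column string schema, so the written values are read back as plain strings before being compared. The pattern, with a placeholder path:

```scala
import org.apache.spark.sql.types._

// Read the written CSV back verbatim as strings by supplying an explicit schema.
val stringSchema = StructType(StructField("date", StringType, true) :: Nil)
val raw = spark.read
  .format("csv")
  .schema(stringSchema)
  .option("header", "true")
  .load("/path/to/output")
```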
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
index 420cff878f..88cb8a0bad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet
 import java.io.File
 import java.math.BigInteger
 import java.sql.{Date, Timestamp}
+import java.util.{Calendar, TimeZone}
 
 import scala.collection.mutable.ArrayBuffer
 
@@ -51,9 +52,12 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
 
   val defaultPartitionName = ExternalCatalogUtils.DEFAULT_PARTITION_NAME
 
+  val timeZone = TimeZone.getDefault()
+  val timeZoneId = timeZone.getID
+
   test("column type inference") {
-    def check(raw: String, literal: Literal): Unit = {
-      assert(inferPartitionColumnValue(raw, true) === literal)
+    def check(raw: String, literal: Literal, timeZone: TimeZone = timeZone): Unit = {
+      assert(inferPartitionColumnValue(raw, true, timeZone) === literal)
     }
 
     check("10", Literal.create(10, IntegerType))
@@ -66,6 +70,14 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
     check("1990-02-24", Literal.create(Date.valueOf("1990-02-24"), DateType))
     check("1990-02-24 12:00:30",
       Literal.create(Timestamp.valueOf("1990-02-24 12:00:30"), TimestampType))
+
+    val c = Calendar.getInstance(TimeZone.getTimeZone("GMT"))
+    c.set(1990, 1, 24, 12, 0, 30)
+    c.set(Calendar.MILLISECOND, 0)
+    check("1990-02-24 12:00:30",
+      Literal.create(new Timestamp(c.getTimeInMillis), TimestampType),
+      TimeZone.getTimeZone("GMT"))
+
     check(defaultPartitionName, Literal.create(null, NullType))
   }
 
@@ -77,7 +89,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
       "hdfs://host:9000/path/a=10.5/b=hello")
 
     var exception = intercept[AssertionError] {
-      parsePartitions(paths.map(new Path(_)), true, Set.empty[Path])
+      parsePartitions(paths.map(new Path(_)), true, Set.empty[Path], timeZoneId)
     }
     assert(exception.getMessage().contains("Conflicting directory structures detected"))
 
@@ -90,7 +102,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
     parsePartitions(
       paths.map(new Path(_)),
       true,
-      Set(new Path("hdfs://host:9000/path/")))
+      Set(new Path("hdfs://host:9000/path/")),
+      timeZoneId)
 
     // Valid
     paths = Seq(
@@ -102,7 +115,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
     parsePartitions(
       paths.map(new Path(_)),
       true,
-      Set(new Path("hdfs://host:9000/path/something=true/table")))
+      Set(new Path("hdfs://host:9000/path/something=true/table")),
+      timeZoneId)
 
     // Valid
     paths = Seq(
@@ -114,7 +128,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
     parsePartitions(
       paths.map(new Path(_)),
       true,
-      Set(new Path("hdfs://host:9000/path/table=true")))
+      Set(new Path("hdfs://host:9000/path/table=true")),
+      timeZoneId)
 
     // Invalid
     paths = Seq(
@@ -126,7 +141,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
       parsePartitions(
         paths.map(new Path(_)),
         true,
-        Set(new Path("hdfs://host:9000/path/")))
+        Set(new Path("hdfs://host:9000/path/")),
+        timeZoneId)
     }
     assert(exception.getMessage().contains("Conflicting directory structures detected"))
 
@@ -145,20 +161,21 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
       parsePartitions(
         paths.map(new Path(_)),
         true,
-        Set(new Path("hdfs://host:9000/tmp/tables/")))
+        Set(new Path("hdfs://host:9000/tmp/tables/")),
+        timeZoneId)
     }
     assert(exception.getMessage().contains("Conflicting directory structures detected"))
   }
 
   test("parse partition") {
     def check(path: String, expected: Option[PartitionValues]): Unit = {
-      val actual = parsePartition(new Path(path), true, Set.empty[Path])._1
+      val actual = parsePartition(new Path(path), true, Set.empty[Path], timeZone)._1
       assert(expected === actual)
     }
 
     def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = {
       val message = intercept[T] {
-        parsePartition(new Path(path), true, Set.empty[Path])
+        parsePartition(new Path(path), true, Set.empty[Path], timeZone)
       }.getMessage
 
       assert(message.contains(expected))
@@ -201,7 +218,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
     val partitionSpec1: Option[PartitionValues] = parsePartition(
       path = new Path("file://path/a=10"),
       typeInference = true,
-      basePaths = Set(new Path("file://path/a=10")))._1
+      basePaths = Set(new Path("file://path/a=10")),
+      timeZone = timeZone)._1
 
     assert(partitionSpec1.isEmpty)
 
@@ -209,7 +227,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
     val partitionSpec2: Option[PartitionValues] = parsePartition(
       path = new Path("file://path/a=10"),
       typeInference = true,
-      basePaths = Set(new Path("file://path")))._1
+      basePaths = Set(new Path("file://path")),
+      timeZone = timeZone)._1
 
     assert(partitionSpec2 ==
       Option(PartitionValues(
@@ -226,7 +245,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
         parsePartitions(
           paths.map(new Path(_)),
           true,
-          rootPaths)
+          rootPaths,
+          timeZoneId)
       assert(actualSpec === spec)
     }
 
@@ -307,7 +327,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
   test("parse partitions with type inference disabled") {
     def check(paths: Seq[String], spec: PartitionSpec): Unit = {
       val actualSpec =
-        parsePartitions(paths.map(new Path(_)), false, Set.empty[Path])
+        parsePartitions(paths.map(new Path(_)), false, Set.empty[Path], timeZoneId)
       assert(actualSpec === spec)
     }
 
@@ -686,6 +706,13 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
       val fields = schema.map(f => Column(f.name).cast(f.dataType))
       checkAnswer(spark.read.load(dir.toString).select(fields: _*), row)
     }
+
+    withTempPath { dir =>
+      df.write.option("timeZone", "GMT")
+        .format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString)
+      val fields = schema.map(f => Column(f.name).cast(f.dataType))
+      checkAnswer(spark.read.option("timeZone", "GMT").load(dir.toString).select(fields: _*), row)
+    }
   }
 
   test("Various inferred partition value types") {
@@ -720,6 +747,13 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
       val fields = schema.map(f => Column(f.name))
       checkAnswer(spark.read.load(dir.toString).select(fields: _*), row)
     }
+
+    withTempPath { dir =>
+      df.write.option("timeZone", "GMT")
+        .format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString)
+      val fields = schema.map(f => Column(f.name))
+      checkAnswer(spark.read.option("timeZone", "GMT").load(dir.toString).select(fields: _*), row)
+    }
   }
 
   test("SPARK-8037: Ignores files whose name starts with dot") {
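The expected literal in the new GMT inference check above is simply the same wall-clock time materialized in GMT; the dynamic-partition write test that follows relies on the same idea, with the shift between the JVM's default timezone and the `timeZone` option showing up in the directory name. A sketch of how the expected epoch for the inference check is obtained (month is 0-based in `Calendar`, so 1 means February):

```scala
import java.sql.Timestamp
import java.util.{Calendar, TimeZone}

val c = Calendar.getInstance(TimeZone.getTimeZone("GMT"))
c.set(1990, 1, 24, 12, 0, 30)   // 1990-02-24 12:00:30 as GMT wall-clock time
c.set(Calendar.MILLISECOND, 0)
val expected = new Timestamp(c.getTimeInMillis)

// Timestamp.valueOf interprets the string in the JVM default zone, so this
// differs from `expected` unless the JVM happens to run in GMT.
val local = Timestamp.valueOf("1990-02-24 12:00:30")
```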
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
index bf7fabe332..f251290583 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
@@ -18,11 +18,13 @@
 package org.apache.spark.sql.sources
 
 import java.io.File
+import java.sql.Timestamp
 
 import org.apache.hadoop.mapreduce.TaskAttemptContext
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils
 import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -124,6 +126,39 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("timeZone setting in dynamic partition writes") {
+    def checkPartitionValues(file: File, expected: String): Unit = {
+      val dir = file.getParentFile()
+      val value = ExternalCatalogUtils.unescapePathName(
+        dir.getName.substring(dir.getName.indexOf("=") + 1))
+      assert(value == expected)
+    }
+    val ts = Timestamp.valueOf("2016-12-01 00:00:00")
+    val df = Seq((1, ts)).toDF("i", "ts")
+    withTempPath { f =>
+      df.write.partitionBy("ts").parquet(f.getAbsolutePath)
+      val files = recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet"))
+      assert(files.length == 1)
+      checkPartitionValues(files.head, "2016-12-01 00:00:00")
+    }
+    withTempPath { f =>
+      df.write.option("timeZone", "GMT").partitionBy("ts").parquet(f.getAbsolutePath)
+      val files = recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet"))
+      assert(files.length == 1)
+      // use timeZone option "GMT" to format partition value.
+      checkPartitionValues(files.head, "2016-12-01 08:00:00")
+    }
+    withTempPath { f =>
+      withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "GMT") {
+        df.write.partitionBy("ts").parquet(f.getAbsolutePath)
+        val files = recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet"))
+        assert(files.length == 1)
+        // if there isn't timeZone option, then use session local timezone.
+        checkPartitionValues(files.head, "2016-12-01 08:00:00")
+      }
+    }
+  }
+
   /** Lists files recursively. */
   private def recursiveList(f: File): Array[File] = {
     require(f.isDirectory)
--
cgit v1.2.3