| author | Takuya UESHIN <ueshin@happy-camper.st> | 2017-01-26 11:51:05 +0100 |
| --- | --- | --- |
| committer | Herman van Hovell <hvanhovell@databricks.com> | 2017-01-26 11:51:05 +0100 |
| commit | 2969fb4370120a39dae98be716b24dcc0ada2cef (patch) | |
| tree | abb577f50071ed97d2934db8a896e58e30907309 /sql/core/src/main/scala/org/apache | |
| parent | 7045b8b3554459fe61c4b32868560e66444a2876 (diff) | |
# [SPARK-18936][SQL] Infrastructure for session local timezone support.
## What changes were proposed in this pull request?
As of Spark 2.1, Spark SQL assumes the machine timezone for datetime manipulation, which is problematic when users are not in the same timezone as the machines, or when different users have different timezones.
We should introduce a session local timezone setting that is used for execution.
An explicit non-goal is locale handling.
### Semantics
Setting the session local timezone means that the timezone-aware expressions listed below should use that timezone to evaluate values, and that it should also be used when converting (casting) between string and timestamp, and between timestamp and date.
- `CurrentDate`
- `CurrentBatchTimestamp`
- `Hour`
- `Minute`
- `Second`
- `DateFormatClass`
- `ToUnixTimestamp`
- `UnixTimestamp`
- `FromUnixTime`
The following are implicitly timezone-aware through the cast from timestamp to date:
- `DayOfYear`
- `Year`
- `Quarter`
- `Month`
- `DayOfMonth`
- `WeekOfYear`
- `LastDay`
- `NextDay`
- `TruncDate`
For example, if you have timestamp `"2016-01-01 00:00:00"` in `GMT`, the values evaluated by some of the timezone-aware expressions are:
```scala
scala> val df = Seq(new java.sql.Timestamp(1451606400000L)).toDF("ts")
df: org.apache.spark.sql.DataFrame = [ts: timestamp]
scala> df.selectExpr("cast(ts as string)", "year(ts)", "month(ts)", "dayofmonth(ts)", "hour(ts)", "minute(ts)", "second(ts)").show(truncate = false)
+-------------------+----------------------+-----------------------+----------------------------+--------+----------+----------+
|ts |year(CAST(ts AS DATE))|month(CAST(ts AS DATE))|dayofmonth(CAST(ts AS DATE))|hour(ts)|minute(ts)|second(ts)|
+-------------------+----------------------+-----------------------+----------------------------+--------+----------+----------+
|2016-01-01 00:00:00|2016 |1 |1 |0 |0 |0 |
+-------------------+----------------------+-----------------------+----------------------------+--------+----------+----------+
```
whereas after setting the session local timezone to `"PST"`, they become:
```scala
scala> spark.conf.set("spark.sql.session.timeZone", "PST")
scala> df.selectExpr("cast(ts as string)", "year(ts)", "month(ts)", "dayofmonth(ts)", "hour(ts)", "minute(ts)", "second(ts)").show(truncate = false)
+-------------------+----------------------+-----------------------+----------------------------+--------+----------+----------+
|ts |year(CAST(ts AS DATE))|month(CAST(ts AS DATE))|dayofmonth(CAST(ts AS DATE))|hour(ts)|minute(ts)|second(ts)|
+-------------------+----------------------+-----------------------+----------------------------+--------+----------+----------+
|2015-12-31 16:00:00|2015 |12 |31 |16 |0 |0 |
+-------------------+----------------------+-----------------------+----------------------------+--------+----------+----------+
```
Notice that even if you set the session local timezone, it affects only `DataFrame` operations; it applies neither to `Dataset` operations nor to `RDD` operations, nor inside `ScalaUDF`s. In those cases you need to handle the timezone yourself, as sketched below.
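For illustration, here is what "handling it yourself" can look like in a `ScalaUDF`. This is only a hypothetical sketch (the UDF `tsInPst` and its format pattern are not part of this patch), reusing the `df` from the example above:

```scala
import java.text.SimpleDateFormat
import java.util.TimeZone

import org.apache.spark.sql.functions.udf

// Hypothetical UDF: the session-local timezone does not reach ScalaUDFs,
// so the timezone has to be applied explicitly on the formatter.
val tsInPst = udf { (ts: java.sql.Timestamp) =>
  val format = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
  format.setTimeZone(TimeZone.getTimeZone("PST"))
  format.format(ts)
}

df.select(tsInPst(df("ts"))).show(truncate = false)
```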
### Design of the fix
I introduced an analyzer rule to pass the session local timezone to timezone-aware expressions, and modified `DateTimeUtils` to take a timezone argument.
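A minimal sketch of that approach, assuming a trait and rule shaped roughly like the ones in the patch (the names and signatures below are approximations, not the exact code):

```scala
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// Timezone-aware expressions carry an optional timezone ID that the
// analyzer fills in from the session configuration.
trait TimeZoneAwareExpression { self: Expression =>
  def timeZoneId: Option[String]
  def withTimeZone(timeZoneId: String): Expression
}

// Analyzer rule: attach the session-local timezone to every
// timezone-aware expression that does not have one yet.
case class ResolveTimeZone(sessionLocalTimeZone: String) extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
    case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
      e.withTimeZone(sessionLocalTimeZone)
  }
}
```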
## How was this patch tested?
Existing tests, plus new tests for the timezone-aware expressions.
Author: Takuya UESHIN <ueshin@happy-camper.st>
Closes #16308 from ueshin/issues/SPARK-18350.
Diffstat (limited to 'sql/core/src/main/scala/org/apache')
10 files changed, 40 insertions, 28 deletions
```diff
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 60182befd7..38029552d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -174,7 +174,7 @@ class Column(val expr: Expression) extends Logging {
       // NamedExpression under this Cast.
       case c: Cast => c.transformUp {
-        case Cast(ne: NamedExpression, to) => UnresolvedAlias(Cast(ne, to))
+        case c @ Cast(_: NamedExpression, _, _) => UnresolvedAlias(c)
       } match {
         case ne: NamedExpression => ne
         case other => Alias(expr, usePrettyExpression(expr).sql)()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 5ee173f72e..391c34f128 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -18,6 +18,8 @@ package org.apache.spark.sql
 
 import java.io.CharArrayWriter
+import java.sql.{Date, Timestamp}
+import java.util.TimeZone
 
 import scala.collection.JavaConverters._
 import scala.language.implicitConversions
@@ -43,7 +45,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
-import org.apache.spark.sql.catalyst.util.usePrettyExpression
+import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DateTimeUtils}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand, GlobalTempView, LocalTempView}
 import org.apache.spark.sql.execution.datasources.LogicalRelation
@@ -250,6 +252,8 @@ class Dataset[T] private[sql](
     val hasMoreData = takeResult.length > numRows
     val data = takeResult.take(numRows)
 
+    lazy val timeZone = TimeZone.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone)
+
     // For array values, replace Seq and Array with square brackets
     // For cells that are beyond `truncate` characters, replace it with the
     // first `truncate-3` and "..."
@@ -260,6 +264,10 @@ class Dataset[T] private[sql](
         case binary: Array[Byte] => binary.map("%02X".format(_)).mkString("[", " ", "]")
         case array: Array[_] => array.mkString("[", ", ", "]")
         case seq: Seq[_] => seq.mkString("[", ", ", "]")
+        case d: Date =>
+          DateTimeUtils.dateToString(DateTimeUtils.fromJavaDate(d))
+        case ts: Timestamp =>
+          DateTimeUtils.timestampToString(DateTimeUtils.fromJavaTimestamp(ts), timeZone)
         case _ => cell.toString
       }
       if (truncate > 0 && str.length > truncate) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
index 0384c0f236..d5a8566d07 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
@@ -369,7 +369,7 @@ class SQLBuilder private (
       case ar: AttributeReference if groupByAttrMap.contains(ar) => groupByAttrMap(ar)
       case a @ Cast(BitwiseAnd(
           ShiftRight(ar: AttributeReference, Literal(value: Any, IntegerType)),
-          Literal(1, IntegerType)), ByteType) if ar == gid =>
+          Literal(1, IntegerType)), ByteType, _) if ar == gid =>
         // for converting an expression to its original SQL format grouping(col)
         val idx = groupByExprs.length - 1 - value.asInstanceOf[Int]
         groupByExprs.lift(idx).map(Grouping).getOrElse(a)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
index 1b7fedca84..b8ac070e3a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
 import org.apache.spark.sql.internal.SQLConf
@@ -104,7 +105,9 @@ case class OptimizeMetadataOnlyQuery(
         val partAttrs = getPartitionAttrs(relation.catalogTable.partitionColumnNames, relation)
         val partitionData = catalog.listPartitions(relation.catalogTable.identifier).map { p =>
           InternalRow.fromSeq(partAttrs.map { attr =>
-            Cast(Literal(p.spec(attr.name)), attr.dataType).eval()
+            // TODO: use correct timezone for partition values.
+            Cast(Literal(p.spec(attr.name)), attr.dataType,
+              Option(DateTimeUtils.defaultTimeZone().getID)).eval()
           })
         }
         LocalRelation(partAttrs, partitionData)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index dcd9003ec6..9d046c0766 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -18,7 +18,8 @@ package org.apache.spark.sql.execution
 
 import java.nio.charset.StandardCharsets
-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}
+import java.util.TimeZone
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
@@ -139,22 +140,6 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
     val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType,
       BooleanType, ByteType, ShortType, DateType, TimestampType, BinaryType)
 
-    /** Implementation following Hive's TimestampWritable.toString */
-    def formatTimestamp(timestamp: Timestamp): String = {
-      val timestampString = timestamp.toString
-      if (timestampString.length() > 19) {
-        if (timestampString.length() == 21) {
-          if (timestampString.substring(19).compareTo(".0") == 0) {
-            return DateTimeUtils.threadLocalTimestampFormat.get().format(timestamp)
-          }
-        }
-        return DateTimeUtils.threadLocalTimestampFormat.get().format(timestamp) +
-          timestampString.substring(19)
-      }
-
-      return DateTimeUtils.threadLocalTimestampFormat.get().format(timestamp)
-    }
-
     def formatDecimal(d: java.math.BigDecimal): String = {
       if (d.compareTo(java.math.BigDecimal.ZERO) == 0) {
         java.math.BigDecimal.ZERO.toPlainString
@@ -195,8 +180,11 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
           toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType))
         }.toSeq.sorted.mkString("{", ",", "}")
       case (null, _) => "NULL"
-      case (d: Int, DateType) => new java.util.Date(DateTimeUtils.daysToMillis(d)).toString
-      case (t: Timestamp, TimestampType) => formatTimestamp(t)
+      case (d: Date, DateType) =>
+        DateTimeUtils.dateToString(DateTimeUtils.fromJavaDate(d))
+      case (t: Timestamp, TimestampType) =>
+        DateTimeUtils.timestampToString(DateTimeUtils.fromJavaTimestamp(t),
+          TimeZone.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone))
       case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8)
       case (decimal: java.math.BigDecimal, DecimalType()) => formatDecimal(decimal)
       case (other, tpe) if primitiveTypes.contains(tpe) => other.toString
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 16c5193eda..be13cbc51a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.{QueryExecution, SQLExecution, UnsafeKVExternalSorter}
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -311,10 +312,11 @@ object FileFormatWriter extends Logging {
   /** Expressions that given a partition key build a string like: col1=val/col2=val/... */
   private def partitionStringExpression: Seq[Expression] = {
     description.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
+      // TODO: use correct timezone for partition values.
       val escaped = ScalaUDF(
         ExternalCatalogUtils.escapePathName _,
         StringType,
-        Seq(Cast(c, StringType)),
+        Seq(Cast(c, StringType, Option(DateTimeUtils.defaultTimeZone().getID))),
         Seq(StringType))
       val str = If(IsNull(c), Literal(ExternalCatalogUtils.DEFAULT_PARTITION_NAME), escaped)
       val partitionName = Literal(c.name + "=") :: str :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
index fe9c6578b1..75f87a5503 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
@@ -30,6 +30,7 @@ import org.apache.spark.metrics.source.HiveCatalogMetrics
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.{expressions, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.util.SerializableConfiguration
@@ -135,9 +136,11 @@ abstract class PartitioningAwareFileIndex(
       // we need to cast into the data type that user specified.
       def castPartitionValuesToUserSchema(row: InternalRow) = {
         InternalRow((0 until row.numFields).map { i =>
+          // TODO: use correct timezone for partition values.
           Cast(
             Literal.create(row.getUTF8String(i), StringType),
-            userProvidedSchema.fields(i).dataType).eval()
+            userProvidedSchema.fields(i).dataType,
+            Option(DateTimeUtils.defaultTimeZone().getID)).eval()
         }: _*)
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 6ab6fa61dc..bd7cec3917 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -58,7 +58,7 @@ class IncrementalExecution(
    */
   override lazy val optimizedPlan: LogicalPlan = {
     sparkSession.sessionState.optimizer.execute(withCachedData) transformAllExpressions {
-      case ts @ CurrentBatchTimestamp(timestamp, _) =>
+      case ts @ CurrentBatchTimestamp(timestamp, _, _) =>
        logInfo(s"Current batch timestamp = $timestamp")
        ts.toLiteral
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index a35950e2dc..ea3719421b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -502,7 +502,7 @@ class StreamExecution(
             ct.dataType)
         case cd: CurrentDate =>
           CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs,
-            cd.dataType)
+            cd.dataType, cd.timeZoneId)
       }
 
     reportTimeTaken("queryPlanning") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index d0c86ffc27..5ba4192512 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.internal
 
-import java.util.{NoSuchElementException, Properties}
+import java.util.{NoSuchElementException, Properties, TimeZone}
 import java.util.concurrent.TimeUnit
 
 import scala.collection.JavaConverters._
@@ -660,6 +660,12 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val SESSION_LOCAL_TIMEZONE =
+    SQLConfigBuilder("spark.sql.session.timeZone")
+      .doc("""The ID of session local timezone, e.g. "GMT", "America/Los_Angeles", etc.""")
+      .stringConf
+      .createWithDefault(TimeZone.getDefault().getID())
+
   object Deprecated {
     val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
   }
@@ -858,6 +864,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
 
   override def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED)
 
+  override def sessionLocalTimeZone: String = getConf(SQLConf.SESSION_LOCAL_TIMEZONE)
+
   def ndvMaxError: Double = getConf(NDV_MAX_ERROR)
 
   override def cboEnabled: Boolean = getConf(SQLConf.CBO_ENABLED)
```
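With the patch applied, the new `spark.sql.session.timeZone` setting can be changed per session at runtime, using the standard `SparkSession` APIs:

```scala
// Set the session-local timezone through the runtime config...
spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")

// ...or equivalently through SQL:
spark.sql("SET spark.sql.session.timeZone=America/Los_Angeles")
```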