aboutsummaryrefslogtreecommitdiff
path: root/sql/core
diff options
context:
space:
mode:
authorTakuya UESHIN <ueshin@happy-camper.st>2017-01-26 11:51:05 +0100
committerHerman van Hovell <hvanhovell@databricks.com>2017-01-26 11:51:05 +0100
commit2969fb4370120a39dae98be716b24dcc0ada2cef (patch)
treeabb577f50071ed97d2934db8a896e58e30907309 /sql/core
parent7045b8b3554459fe61c4b32868560e66444a2876 (diff)
downloadspark-2969fb4370120a39dae98be716b24dcc0ada2cef.tar.gz
spark-2969fb4370120a39dae98be716b24dcc0ada2cef.tar.bz2
spark-2969fb4370120a39dae98be716b24dcc0ada2cef.zip
[SPARK-18936][SQL] Infrastructure for session local timezone support.
## What changes were proposed in this pull request? As of Spark 2.1, Spark SQL assumes the machine timezone for datetime manipulation, which is bad if users are not in the same timezones as the machines, or if different users have different timezones. We should introduce a session local timezone setting that is used for execution. An explicit non-goal is locale handling. ### Semantics Setting the session local timezone means that the timezone-aware expressions listed below should use the timezone to evaluate values, and also it should be used to convert (cast) between string and timestamp or between timestamp and date. - `CurrentDate` - `CurrentBatchTimestamp` - `Hour` - `Minute` - `Second` - `DateFormatClass` - `ToUnixTimestamp` - `UnixTimestamp` - `FromUnixTime` and below are implicitly timezone-aware through cast from timestamp to date: - `DayOfYear` - `Year` - `Quarter` - `Month` - `DayOfMonth` - `WeekOfYear` - `LastDay` - `NextDay` - `TruncDate` For example, if you have timestamp `"2016-01-01 00:00:00"` in `GMT`, the values evaluated by some of timezone-aware expressions are: ```scala scala> val df = Seq(new java.sql.Timestamp(1451606400000L)).toDF("ts") df: org.apache.spark.sql.DataFrame = [ts: timestamp] scala> df.selectExpr("cast(ts as string)", "year(ts)", "month(ts)", "dayofmonth(ts)", "hour(ts)", "minute(ts)", "second(ts)").show(truncate = false) +-------------------+----------------------+-----------------------+----------------------------+--------+----------+----------+ |ts |year(CAST(ts AS DATE))|month(CAST(ts AS DATE))|dayofmonth(CAST(ts AS DATE))|hour(ts)|minute(ts)|second(ts)| +-------------------+----------------------+-----------------------+----------------------------+--------+----------+----------+ |2016-01-01 00:00:00|2016 |1 |1 |0 |0 |0 | +-------------------+----------------------+-----------------------+----------------------------+--------+----------+----------+ ``` whereas setting the session local timezone to `"PST"`, they are: ```scala scala> spark.conf.set("spark.sql.session.timeZone", "PST") scala> df.selectExpr("cast(ts as string)", "year(ts)", "month(ts)", "dayofmonth(ts)", "hour(ts)", "minute(ts)", "second(ts)").show(truncate = false) +-------------------+----------------------+-----------------------+----------------------------+--------+----------+----------+ |ts |year(CAST(ts AS DATE))|month(CAST(ts AS DATE))|dayofmonth(CAST(ts AS DATE))|hour(ts)|minute(ts)|second(ts)| +-------------------+----------------------+-----------------------+----------------------------+--------+----------+----------+ |2015-12-31 16:00:00|2015 |12 |31 |16 |0 |0 | +-------------------+----------------------+-----------------------+----------------------------+--------+----------+----------+ ``` Notice that even if you set the session local timezone, it affects only in `DataFrame` operations, neither in `Dataset` operations, `RDD` operations nor in `ScalaUDF`s. You need to properly handle timezone by yourself. ### Design of the fix I introduced an analyzer to pass session local timezone to timezone-aware expressions and modified DateTimeUtils to take the timezone argument. ## How was this patch tested? Existing tests and added tests for timezone aware expressions. Author: Takuya UESHIN <ueshin@happy-camper.st> Closes #16308 from ueshin/issues/SPARK-18350.
Diffstat (limited to 'sql/core')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Column.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala10
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala5
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala26
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala5
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala10
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala25
11 files changed, 65 insertions, 28 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 60182befd7..38029552d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -174,7 +174,7 @@ class Column(val expr: Expression) extends Logging {
// NamedExpression under this Cast.
case c: Cast =>
c.transformUp {
- case Cast(ne: NamedExpression, to) => UnresolvedAlias(Cast(ne, to))
+ case c @ Cast(_: NamedExpression, _, _) => UnresolvedAlias(c)
} match {
case ne: NamedExpression => ne
case other => Alias(expr, usePrettyExpression(expr).sql)()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 5ee173f72e..391c34f128 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -18,6 +18,8 @@
package org.apache.spark.sql
import java.io.CharArrayWriter
+import java.sql.{Date, Timestamp}
+import java.util.TimeZone
import scala.collection.JavaConverters._
import scala.language.implicitConversions
@@ -43,7 +45,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
-import org.apache.spark.sql.catalyst.util.usePrettyExpression
+import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DateTimeUtils}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand, GlobalTempView, LocalTempView}
import org.apache.spark.sql.execution.datasources.LogicalRelation
@@ -250,6 +252,8 @@ class Dataset[T] private[sql](
val hasMoreData = takeResult.length > numRows
val data = takeResult.take(numRows)
+ lazy val timeZone = TimeZone.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone)
+
// For array values, replace Seq and Array with square brackets
// For cells that are beyond `truncate` characters, replace it with the
// first `truncate-3` and "..."
@@ -260,6 +264,10 @@ class Dataset[T] private[sql](
case binary: Array[Byte] => binary.map("%02X".format(_)).mkString("[", " ", "]")
case array: Array[_] => array.mkString("[", ", ", "]")
case seq: Seq[_] => seq.mkString("[", ", ", "]")
+ case d: Date =>
+ DateTimeUtils.dateToString(DateTimeUtils.fromJavaDate(d))
+ case ts: Timestamp =>
+ DateTimeUtils.timestampToString(DateTimeUtils.fromJavaTimestamp(ts), timeZone)
case _ => cell.toString
}
if (truncate > 0 && str.length > truncate) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
index 0384c0f236..d5a8566d07 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
@@ -369,7 +369,7 @@ class SQLBuilder private (
case ar: AttributeReference if groupByAttrMap.contains(ar) => groupByAttrMap(ar)
case a @ Cast(BitwiseAnd(
ShiftRight(ar: AttributeReference, Literal(value: Any, IntegerType)),
- Literal(1, IntegerType)), ByteType) if ar == gid =>
+ Literal(1, IntegerType)), ByteType, _) if ar == gid =>
// for converting an expression to its original SQL format grouping(col)
val idx = groupByExprs.length - 1 - value.asInstanceOf[Int]
groupByExprs.lift(idx).map(Grouping).getOrElse(a)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
index 1b7fedca84..b8ac070e3a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.internal.SQLConf
@@ -104,7 +105,9 @@ case class OptimizeMetadataOnlyQuery(
val partAttrs = getPartitionAttrs(relation.catalogTable.partitionColumnNames, relation)
val partitionData = catalog.listPartitions(relation.catalogTable.identifier).map { p =>
InternalRow.fromSeq(partAttrs.map { attr =>
- Cast(Literal(p.spec(attr.name)), attr.dataType).eval()
+ // TODO: use correct timezone for partition values.
+ Cast(Literal(p.spec(attr.name)), attr.dataType,
+ Option(DateTimeUtils.defaultTimeZone().getID)).eval()
})
}
LocalRelation(partAttrs, partitionData)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index dcd9003ec6..9d046c0766 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -18,7 +18,8 @@
package org.apache.spark.sql.execution
import java.nio.charset.StandardCharsets
-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}
+import java.util.TimeZone
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
@@ -139,22 +140,6 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType,
BooleanType, ByteType, ShortType, DateType, TimestampType, BinaryType)
- /** Implementation following Hive's TimestampWritable.toString */
- def formatTimestamp(timestamp: Timestamp): String = {
- val timestampString = timestamp.toString
- if (timestampString.length() > 19) {
- if (timestampString.length() == 21) {
- if (timestampString.substring(19).compareTo(".0") == 0) {
- return DateTimeUtils.threadLocalTimestampFormat.get().format(timestamp)
- }
- }
- return DateTimeUtils.threadLocalTimestampFormat.get().format(timestamp) +
- timestampString.substring(19)
- }
-
- return DateTimeUtils.threadLocalTimestampFormat.get().format(timestamp)
- }
-
def formatDecimal(d: java.math.BigDecimal): String = {
if (d.compareTo(java.math.BigDecimal.ZERO) == 0) {
java.math.BigDecimal.ZERO.toPlainString
@@ -195,8 +180,11 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType))
}.toSeq.sorted.mkString("{", ",", "}")
case (null, _) => "NULL"
- case (d: Int, DateType) => new java.util.Date(DateTimeUtils.daysToMillis(d)).toString
- case (t: Timestamp, TimestampType) => formatTimestamp(t)
+ case (d: Date, DateType) =>
+ DateTimeUtils.dateToString(DateTimeUtils.fromJavaDate(d))
+ case (t: Timestamp, TimestampType) =>
+ DateTimeUtils.timestampToString(DateTimeUtils.fromJavaTimestamp(t),
+ TimeZone.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone))
case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8)
case (decimal: java.math.BigDecimal, DecimalType()) => formatDecimal(decimal)
case (other, tpe) if primitiveTypes.contains(tpe) => other.toString
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 16c5193eda..be13cbc51a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.{QueryExecution, SQLExecution, UnsafeKVExternalSorter}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -311,10 +312,11 @@ object FileFormatWriter extends Logging {
/** Expressions that given a partition key build a string like: col1=val/col2=val/... */
private def partitionStringExpression: Seq[Expression] = {
description.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
+ // TODO: use correct timezone for partition values.
val escaped = ScalaUDF(
ExternalCatalogUtils.escapePathName _,
StringType,
- Seq(Cast(c, StringType)),
+ Seq(Cast(c, StringType, Option(DateTimeUtils.defaultTimeZone().getID))),
Seq(StringType))
val str = If(IsNull(c), Literal(ExternalCatalogUtils.DEFAULT_PARTITION_NAME), escaped)
val partitionName = Literal(c.name + "=") :: str :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
index fe9c6578b1..75f87a5503 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
@@ -30,6 +30,7 @@ import org.apache.spark.metrics.source.HiveCatalogMetrics
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.util.SerializableConfiguration
@@ -135,9 +136,11 @@ abstract class PartitioningAwareFileIndex(
// we need to cast into the data type that user specified.
def castPartitionValuesToUserSchema(row: InternalRow) = {
InternalRow((0 until row.numFields).map { i =>
+ // TODO: use correct timezone for partition values.
Cast(
Literal.create(row.getUTF8String(i), StringType),
- userProvidedSchema.fields(i).dataType).eval()
+ userProvidedSchema.fields(i).dataType,
+ Option(DateTimeUtils.defaultTimeZone().getID)).eval()
}: _*)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 6ab6fa61dc..bd7cec3917 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -58,7 +58,7 @@ class IncrementalExecution(
*/
override lazy val optimizedPlan: LogicalPlan = {
sparkSession.sessionState.optimizer.execute(withCachedData) transformAllExpressions {
- case ts @ CurrentBatchTimestamp(timestamp, _) =>
+ case ts @ CurrentBatchTimestamp(timestamp, _, _) =>
logInfo(s"Current batch timestamp = $timestamp")
ts.toLiteral
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index a35950e2dc..ea3719421b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -502,7 +502,7 @@ class StreamExecution(
ct.dataType)
case cd: CurrentDate =>
CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs,
- cd.dataType)
+ cd.dataType, cd.timeZoneId)
}
reportTimeTaken("queryPlanning") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index d0c86ffc27..5ba4192512 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.internal
-import java.util.{NoSuchElementException, Properties}
+import java.util.{NoSuchElementException, Properties, TimeZone}
import java.util.concurrent.TimeUnit
import scala.collection.JavaConverters._
@@ -660,6 +660,12 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+ val SESSION_LOCAL_TIMEZONE =
+ SQLConfigBuilder("spark.sql.session.timeZone")
+ .doc("""The ID of session local timezone, e.g. "GMT", "America/Los_Angeles", etc.""")
+ .stringConf
+ .createWithDefault(TimeZone.getDefault().getID())
+
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
}
@@ -858,6 +864,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
override def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED)
+ override def sessionLocalTimeZone: String = getConf(SQLConf.SESSION_LOCAL_TIMEZONE)
+
def ndvMaxError: Double = getConf(NDV_MAX_ERROR)
override def cboEnabled: Boolean = getConf(SQLConf.CBO_ENABLED)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index cb7b97906a..6a190b98ea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -29,6 +29,7 @@ import org.scalatest.Matchers._
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Project, Union}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.{FilterExec, QueryExecution}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange}
@@ -869,6 +870,30 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
assert(testData.select($"*").filter($"key" < 0).showString(1) === expectedAnswer)
}
+ test("SPARK-18350 show with session local timezone") {
+ val d = Date.valueOf("2016-12-01")
+ val ts = Timestamp.valueOf("2016-12-01 00:00:00")
+ val df = Seq((d, ts)).toDF("d", "ts")
+ val expectedAnswer = """+----------+-------------------+
+ ||d |ts |
+ |+----------+-------------------+
+ ||2016-12-01|2016-12-01 00:00:00|
+ |+----------+-------------------+
+ |""".stripMargin
+ assert(df.showString(1, truncate = 0) === expectedAnswer)
+
+ withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "GMT") {
+
+ val expectedAnswer = """+----------+-------------------+
+ ||d |ts |
+ |+----------+-------------------+
+ ||2016-12-01|2016-12-01 08:00:00|
+ |+----------+-------------------+
+ |""".stripMargin
+ assert(df.showString(1, truncate = 0) === expectedAnswer)
+ }
+ }
+
test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") {
val rowRDD = sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0))))
val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false)))