author    Herman van Hovell <hvanhovell@databricks.com>  2017-04-21 10:06:12 +0800
committer Wenchen Fan <wenchen@databricks.com>  2017-04-21 10:06:12 +0800
commit    760c8d088df1d35d7b8942177d47bc1677daf143 (patch)
tree      86e975adf8963dd5455a2f9c11607a9974c21373
parent    0368eb9d86634c83b3140ce3190cb9e0d0b7fd86 (diff)
[SPARK-20329][SQL] Make timezone aware expression without timezone unresolved
## What changes were proposed in this pull request?

A cast expression with a resolved time zone is not equal to a cast expression without a resolved time zone. `ResolveAggregateFunctions` assumed that these expressions were the same, and would therefore fail to resolve `HAVING` clauses that contain a `Cast` expression. The root cause is that a `TimeZoneAwareExpression` could be considered resolved without a time zone being set. This PR fixes this by making a `TimeZoneAwareExpression` unresolved as long as it has no time zone set.

## How was this patch tested?

Added a regression test to the `SQLQueryTestSuite.having` file.

Author: Herman van Hovell <hvanhovell@databricks.com>

Closes #17641 from hvanhovell/SPARK-20329.
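As a rough illustration of the idea (a minimal sketch, not the actual Spark code; `TimeZoneAware`, `ToyCast`, and `TimeZoneDemo` are made-up names), an expression that needs a time zone only reports itself as resolved once that time zone has been set, and a copy with a time zone is not equal to a copy without one:

```scala
// Minimal sketch of the fix -- illustrative stand-ins, not the real Spark classes.
trait TimeZoneAware {
  def timeZoneId: Option[String]
  def childrenResolved: Boolean
  // Core of the change: without a time zone the expression reports itself as
  // unresolved, so the analyzer keeps rewriting it instead of comparing it
  // against a copy that already carries a time zone.
  def resolved: Boolean = childrenResolved && timeZoneId.isDefined
}

case class ToyCast(childResolved: Boolean, timeZoneId: Option[String]) extends TimeZoneAware {
  def childrenResolved: Boolean = childResolved
  def withTimeZone(tz: String): ToyCast = copy(timeZoneId = Some(tz))
}

object TimeZoneDemo extends App {
  val c = ToyCast(childResolved = true, timeZoneId = None)
  println(c.resolved)                     // false: no time zone set yet
  println(c.withTimeZone("GMT").resolved) // true: time zone has been filled in
  println(c == c.withTimeZone("GMT"))     // false: the two copies are not equal,
                                          // which is what tripped up HAVING resolution
}
```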
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 20
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala | 10
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala | 61
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala | 4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala | 10
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala | 35
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala | 4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala | 6
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala | 7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala | 20
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala | 2
-rw-r--r--  sql/core/src/test/resources/sql-tests/inputs/having.sql | 3
-rw-r--r--  sql/core/src/test/resources/sql-tests/results/having.sql.out | 11
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala | 3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala | 16
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala | 2
19 files changed, 148 insertions(+), 78 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index eafeb4ac1a..dcadbbc90f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -150,6 +150,7 @@ class Analyzer(
ResolveAggregateFunctions ::
TimeWindowing ::
ResolveInlineTables(conf) ::
+ ResolveTimeZone(conf) ::
TypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
@@ -161,8 +162,6 @@ class Analyzer(
HandleNullInputsForUDF),
Batch("FixNullability", Once,
FixNullability),
- Batch("ResolveTimeZone", Once,
- ResolveTimeZone),
Batch("Subquery", Once,
UpdateOuterReferences),
Batch("Cleanup", fixedPoint,
@@ -2368,23 +2367,6 @@ class Analyzer(
}
}
}
-
- /**
- * Replace [[TimeZoneAwareExpression]] without timezone id by its copy with session local
- * time zone.
- */
- object ResolveTimeZone extends Rule[LogicalPlan] {
-
- override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions {
- case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
- e.withTimeZone(conf.sessionLocalTimeZone)
- // Casts could be added in the subquery plan through the rule TypeCoercion while coercing
- // the types between the value expression and list query expression of IN expression.
- // We need to subject the subquery plan through ResolveTimeZone again to setup timezone
- // information for time zone aware expressions.
- case e: ListQuery => e.withNewPlan(apply(e.plan))
- }
- }
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
index a991dd96e2..f2df3e1326 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.analysis
import scala.util.control.NonFatal
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Cast, TimeZoneAwareExpression}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
@@ -29,7 +28,7 @@ import org.apache.spark.sql.types.{StructField, StructType}
/**
* An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]].
*/
-case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] {
+case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport {
override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case table: UnresolvedInlineTable if table.expressionsResolved =>
validateInputDimension(table)
@@ -99,12 +98,9 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] {
val castedExpr = if (e.dataType.sameType(targetType)) {
e
} else {
- Cast(e, targetType)
+ cast(e, targetType)
}
- castedExpr.transform {
- case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
- e.withTimeZone(conf.sessionLocalTimeZone)
- }.eval()
+ castedExpr.eval()
} catch {
case NonFatal(ex) =>
table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala
new file mode 100644
index 0000000000..a27aa845bf
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ListQuery, TimeZoneAwareExpression}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.DataType
+
+/**
+ * Replace [[TimeZoneAwareExpression]] without timezone id by its copy with session local
+ * time zone.
+ */
+case class ResolveTimeZone(conf: SQLConf) extends Rule[LogicalPlan] {
+ private val transformTimeZoneExprs: PartialFunction[Expression, Expression] = {
+ case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
+ e.withTimeZone(conf.sessionLocalTimeZone)
+ // Casts could be added in the subquery plan through the rule TypeCoercion while coercing
+ // the types between the value expression and list query expression of IN expression.
+ // We need to subject the subquery plan through ResolveTimeZone again to setup timezone
+ // information for time zone aware expressions.
+ case e: ListQuery => e.withNewPlan(apply(e.plan))
+ }
+
+ override def apply(plan: LogicalPlan): LogicalPlan =
+ plan.resolveExpressions(transformTimeZoneExprs)
+
+ def resolveTimeZones(e: Expression): Expression = e.transform(transformTimeZoneExprs)
+}
+
+/**
+ * Mix-in trait for constructing valid [[Cast]] expressions.
+ */
+trait CastSupport {
+ /**
+ * Configuration used to create a valid cast expression.
+ */
+ def conf: SQLConf
+
+ /**
+ * Create a Cast expression with the session local time zone.
+ */
+ def cast(child: Expression, dataType: DataType): Cast = {
+ Cast(child, dataType, Option(conf.sessionLocalTimeZone))
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala
index 3bd54c257d..ea46dd7282 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala
@@ -47,7 +47,7 @@ import org.apache.spark.sql.internal.SQLConf
* This should be only done after the batch of Resolution, because the view attributes are not
* completely resolved during the batch of Resolution.
*/
-case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] {
+case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case v @ View(desc, output, child) if child.resolved && output != child.output =>
val resolver = conf.resolver
@@ -78,7 +78,7 @@ case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] {
throw new AnalysisException(s"Cannot up cast ${originAttr.sql} from " +
s"${originAttr.dataType.simpleString} to ${attr.simpleString} as it may truncate\n")
} else {
- Alias(Cast(originAttr, attr.dataType), attr.name)(exprId = attr.exprId,
+ Alias(cast(originAttr, attr.dataType), attr.name)(exprId = attr.exprId,
qualifier = attr.qualifier, explicitMetadata = Some(attr.metadata))
}
case (_, originAttr) => originAttr
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index f8fe774823..bb8fd5032d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -24,7 +24,6 @@ import java.util.{Calendar, TimeZone}
import scala.util.control.NonFatal
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
@@ -34,6 +33,9 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
* Common base class for time zone aware expressions.
*/
trait TimeZoneAwareExpression extends Expression {
+ /** The expression is only resolved when the time zone has been set. */
+ override lazy val resolved: Boolean =
+ childrenResolved && checkInputDataTypes().isSuccess && timeZoneId.isDefined
/** the timezone ID to be used to evaluate value. */
def timeZoneId: Option[String]
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala
index f45a826869..d0fe815052 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala
@@ -22,6 +22,7 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, Rand}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.{LongType, NullType, TimestampType}
/**
@@ -91,12 +92,13 @@ class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter {
test("convert TimeZoneAwareExpression") {
val table = UnresolvedInlineTable(Seq("c1"),
Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType))))
- val converted = ResolveInlineTables(conf).convert(table)
+ val withTimeZone = ResolveTimeZone(conf).apply(table)
+ val LocalRelation(output, data) = ResolveInlineTables(conf).apply(withTimeZone)
val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType)
.withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long]
- assert(converted.output.map(_.dataType) == Seq(TimestampType))
- assert(converted.data.size == 1)
- assert(converted.data(0).getLong(0) == correct)
+ assert(output.map(_.dataType) == Seq(TimestampType))
+ assert(data.size == 1)
+ assert(data.head.getLong(0) == correct)
}
test("nullability inference in convert") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 011d09ff60..2624f5586f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
@@ -787,6 +788,12 @@ class TypeCoercionSuite extends PlanTest {
}
}
+ private val timeZoneResolver = ResolveTimeZone(new SQLConf)
+
+ private def widenSetOperationTypes(plan: LogicalPlan): LogicalPlan = {
+ timeZoneResolver(TypeCoercion.WidenSetOperationTypes(plan))
+ }
+
test("WidenSetOperationTypes for except and intersect") {
val firstTable = LocalRelation(
AttributeReference("i", IntegerType)(),
@@ -799,11 +806,10 @@ class TypeCoercionSuite extends PlanTest {
AttributeReference("f", FloatType)(),
AttributeReference("l", LongType)())
- val wt = TypeCoercion.WidenSetOperationTypes
val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType)
- val r1 = wt(Except(firstTable, secondTable)).asInstanceOf[Except]
- val r2 = wt(Intersect(firstTable, secondTable)).asInstanceOf[Intersect]
+ val r1 = widenSetOperationTypes(Except(firstTable, secondTable)).asInstanceOf[Except]
+ val r2 = widenSetOperationTypes(Intersect(firstTable, secondTable)).asInstanceOf[Intersect]
checkOutput(r1.left, expectedTypes)
checkOutput(r1.right, expectedTypes)
checkOutput(r2.left, expectedTypes)
@@ -838,10 +844,9 @@ class TypeCoercionSuite extends PlanTest {
AttributeReference("p", ByteType)(),
AttributeReference("q", DoubleType)())
- val wt = TypeCoercion.WidenSetOperationTypes
val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType)
- val unionRelation = wt(
+ val unionRelation = widenSetOperationTypes(
Union(firstTable :: secondTable :: thirdTable :: forthTable :: Nil)).asInstanceOf[Union]
assert(unionRelation.children.length == 4)
checkOutput(unionRelation.children.head, expectedTypes)
@@ -862,17 +867,15 @@ class TypeCoercionSuite extends PlanTest {
}
}
- val dp = TypeCoercion.WidenSetOperationTypes
-
val left1 = LocalRelation(
AttributeReference("l", DecimalType(10, 8))())
val right1 = LocalRelation(
AttributeReference("r", DecimalType(5, 5))())
val expectedType1 = Seq(DecimalType(10, 8))
- val r1 = dp(Union(left1, right1)).asInstanceOf[Union]
- val r2 = dp(Except(left1, right1)).asInstanceOf[Except]
- val r3 = dp(Intersect(left1, right1)).asInstanceOf[Intersect]
+ val r1 = widenSetOperationTypes(Union(left1, right1)).asInstanceOf[Union]
+ val r2 = widenSetOperationTypes(Except(left1, right1)).asInstanceOf[Except]
+ val r3 = widenSetOperationTypes(Intersect(left1, right1)).asInstanceOf[Intersect]
checkOutput(r1.children.head, expectedType1)
checkOutput(r1.children.last, expectedType1)
@@ -891,17 +894,17 @@ class TypeCoercionSuite extends PlanTest {
val plan2 = LocalRelation(
AttributeReference("r", rType)())
- val r1 = dp(Union(plan1, plan2)).asInstanceOf[Union]
- val r2 = dp(Except(plan1, plan2)).asInstanceOf[Except]
- val r3 = dp(Intersect(plan1, plan2)).asInstanceOf[Intersect]
+ val r1 = widenSetOperationTypes(Union(plan1, plan2)).asInstanceOf[Union]
+ val r2 = widenSetOperationTypes(Except(plan1, plan2)).asInstanceOf[Except]
+ val r3 = widenSetOperationTypes(Intersect(plan1, plan2)).asInstanceOf[Intersect]
checkOutput(r1.children.last, Seq(expectedType))
checkOutput(r2.right, Seq(expectedType))
checkOutput(r3.right, Seq(expectedType))
- val r4 = dp(Union(plan2, plan1)).asInstanceOf[Union]
- val r5 = dp(Except(plan2, plan1)).asInstanceOf[Except]
- val r6 = dp(Intersect(plan2, plan1)).asInstanceOf[Intersect]
+ val r4 = widenSetOperationTypes(Union(plan2, plan1)).asInstanceOf[Union]
+ val r5 = widenSetOperationTypes(Except(plan2, plan1)).asInstanceOf[Except]
+ val r6 = widenSetOperationTypes(Intersect(plan2, plan1)).asInstanceOf[Intersect]
checkOutput(r4.children.last, Seq(expectedType))
checkOutput(r5.left, Seq(expectedType))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index a7ffa884d2..22f3f3514f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -34,7 +34,7 @@ import org.apache.spark.unsafe.types.UTF8String
*/
class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
- private def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): Cast = {
+ private def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = Some("GMT")): Cast = {
v match {
case lit: Expression => Cast(lit, targetType, timeZoneId)
case _ => Cast(Literal(v), targetType, timeZoneId)
@@ -47,7 +47,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
}
private def checkNullCast(from: DataType, to: DataType): Unit = {
- checkEvaluation(cast(Literal.create(null, from), to, Option("GMT")), null)
+ checkEvaluation(cast(Literal.create(null, from), to), null)
}
test("null cast") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
index 9978f35a03..ca89bf7db0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
@@ -160,7 +160,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("Seconds") {
assert(Second(Literal.create(null, DateType), gmtId).resolved === false)
- assert(Second(Cast(Literal(d), TimestampType), None).resolved === true)
+ assert(Second(Cast(Literal(d), TimestampType, gmtId), gmtId).resolved === true)
checkEvaluation(Second(Cast(Literal(d), TimestampType, gmtId), gmtId), 0)
checkEvaluation(Second(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 15)
checkEvaluation(Second(Literal(ts), gmtId), 15)
@@ -220,7 +220,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("Hour") {
assert(Hour(Literal.create(null, DateType), gmtId).resolved === false)
- assert(Hour(Literal(ts), None).resolved === true)
+ assert(Hour(Literal(ts), gmtId).resolved === true)
checkEvaluation(Hour(Cast(Literal(d), TimestampType, gmtId), gmtId), 0)
checkEvaluation(Hour(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 13)
checkEvaluation(Hour(Literal(ts), gmtId), 13)
@@ -246,7 +246,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("Minute") {
assert(Minute(Literal.create(null, DateType), gmtId).resolved === false)
- assert(Minute(Literal(ts), None).resolved === true)
+ assert(Minute(Literal(ts), gmtId).resolved === true)
checkEvaluation(Minute(Cast(Literal(d), TimestampType, gmtId), gmtId), 0)
checkEvaluation(
Minute(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 10)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 1ba6dd1c5e..b6399edb68 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -25,10 +25,12 @@ import org.scalatest.prop.GeneratorDrivenPropertyChecks
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -45,7 +47,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
protected def checkEvaluation(
expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = {
val serializer = new JavaSerializer(new SparkConf()).newInstance
- val expr: Expression = serializer.deserialize(serializer.serialize(expression))
+ val resolver = ResolveTimeZone(new SQLConf)
+ val expr = resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression)))
val catalystValue = CatalystTypeConverters.convertToCatalyst(expected)
checkEvaluationWithoutCodegen(expr, catalystValue, inputRow)
checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index 6566502bd8..4e718d609c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -36,7 +36,7 @@ class SparkPlanner(
experimentalMethods.extraStrategies ++
extraPlanningStrategies ++ (
FileSourceStrategy ::
- DataSourceStrategy ::
+ DataSourceStrategy(conf) ::
SpecialLimits ::
Aggregation ::
JoinSelection ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 2d83d512e7..d307122b5c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QualifiedTableName, TableIdentifier}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QualifiedTableName}
import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogUtils}
@@ -48,7 +48,7 @@ import org.apache.spark.unsafe.types.UTF8String
* Note that, this rule must be run after `PreprocessTableCreation` and
* `PreprocessTableInsertion`.
*/
-case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] {
+case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport {
def resolver: Resolver = conf.resolver
@@ -98,11 +98,11 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] {
val potentialSpecs = staticPartitions.filter {
case (partKey, partValue) => resolver(field.name, partKey)
}
- if (potentialSpecs.size == 0) {
+ if (potentialSpecs.isEmpty) {
None
} else if (potentialSpecs.size == 1) {
val partValue = potentialSpecs.head._2
- Some(Alias(Cast(Literal(partValue), field.dataType), field.name)())
+ Some(Alias(cast(Literal(partValue), field.dataType), field.name)())
} else {
throw new AnalysisException(
s"Partition column ${field.name} have multiple values specified, " +
@@ -258,7 +258,9 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
/**
* A Strategy for planning scans over data sources defined using the sources API.
*/
-object DataSourceStrategy extends Strategy with Logging {
+case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with CastSupport {
+ import DataSourceStrategy._
+
def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match {
case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _)) =>
pruneFilterProjectRaw(
@@ -298,7 +300,7 @@ object DataSourceStrategy extends Strategy with Logging {
// Restriction: Bucket pruning works iff the bucketing column has one and only one column.
def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = {
val mutableRow = new SpecificInternalRow(Seq(bucketColumn.dataType))
- mutableRow(0) = Cast(Literal(value), bucketColumn.dataType).eval(null)
+ mutableRow(0) = cast(Literal(value), bucketColumn.dataType).eval(null)
val bucketIdGeneration = UnsafeProjection.create(
HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil,
bucketColumn :: Nil)
@@ -436,7 +438,9 @@ object DataSourceStrategy extends Strategy with Logging {
private[this] def toCatalystRDD(relation: LogicalRelation, rdd: RDD[Row]): RDD[InternalRow] = {
toCatalystRDD(relation, relation.output, rdd)
}
+}
+object DataSourceStrategy {
/**
* Tries to translate a Catalyst [[Expression]] into data source [[Filter]].
*
@@ -527,8 +531,8 @@ object DataSourceStrategy extends Strategy with Logging {
* all [[Filter]]s that are completely filtered at the DataSource.
*/
protected[sql] def selectFilters(
- relation: BaseRelation,
- predicates: Seq[Expression]): (Seq[Expression], Seq[Filter], Set[Filter]) = {
+ relation: BaseRelation,
+ predicates: Seq[Expression]): (Seq[Expression], Seq[Filter], Set[Filter]) = {
// For conciseness, all Catalyst filter expressions of type `expressions.Expression` below are
// called `predicate`s, while all data source filters of type `sources.Filter` are simply called
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 7abf2ae516..3f4a78580f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -22,7 +22,7 @@ import java.util.Locale
import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, RowOrdering}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.command.DDLUtils
@@ -315,7 +315,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi
* table. It also does data type casting and field renaming, to make sure that the columns to be
* inserted have the correct data type and fields have the correct names.
*/
-case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] {
+case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport {
private def preprocess(
insert: InsertIntoTable,
tblName: String,
@@ -367,7 +367,7 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] {
// Renaming is needed for handling the following cases like
// 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2
// 2) Target tables have column metadata
- Alias(Cast(actual, expected.dataType), expected.name)(
+ Alias(cast(actual, expected.dataType), expected.name)(
explicitMetadata = Option(expected.metadata))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 2b14eca919..df7c3678b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.internal
import org.apache.spark.SparkConf
import org.apache.spark.annotation.{Experimental, InterfaceStability}
import org.apache.spark.sql.{ExperimentalMethods, SparkSession, Strategy, UDFRegistration}
-import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry, ResolveTimeZone}
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.parser.ParserInterface
diff --git a/sql/core/src/test/resources/sql-tests/inputs/having.sql b/sql/core/src/test/resources/sql-tests/inputs/having.sql
index 364c022d95..868a911e78 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/having.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/having.sql
@@ -13,3 +13,6 @@ SELECT count(k) FROM hav GROUP BY v + 1 HAVING v + 1 = 2;
-- SPARK-11032: resolve having correctly
SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0);
+
+-- SPARK-20329: make sure we handle timezones correctly
+SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1;
diff --git a/sql/core/src/test/resources/sql-tests/results/having.sql.out b/sql/core/src/test/resources/sql-tests/results/having.sql.out
index e092383267..d87ee52216 100644
--- a/sql/core/src/test/resources/sql-tests/results/having.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/having.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 4
+-- Number of queries: 5
-- !query 0
@@ -38,3 +38,12 @@ SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0)
struct<min(v):int>
-- !query 3 output
1
+
+
+-- !query 4
+SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1
+-- !query 4 schema
+struct<(a + CAST(b AS BIGINT)):bigint>
+-- !query 4 output
+3
+7
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 9b65419dba..ba0ca666b5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -90,6 +90,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
originalDataFrame: DataFrame): Unit = {
// This test verifies parts of the plan. Disable whole stage codegen.
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
+ val strategy = DataSourceStrategy(spark.sessionState.conf)
val bucketedDataFrame = spark.table("bucketed_table").select("i", "j", "k")
val BucketSpec(numBuckets, bucketColumnNames, _) = bucketSpec
// Limit: bucket pruning only works when the bucket column has one and only one column
@@ -98,7 +99,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
val bucketColumn = bucketedDataFrame.schema.toAttributes(bucketColumnIndex)
val matchedBuckets = new BitSet(numBuckets)
bucketValues.foreach { value =>
- matchedBuckets.set(DataSourceStrategy.getBucketId(bucketColumn, numBuckets, value))
+ matchedBuckets.set(strategy.getBucketId(bucketColumn, numBuckets, value))
}
// Filter could hide the bug in bucket pruning. Thus, skipping all the filters
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala
index b16c9f8fc9..735e07c213 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, Expression, Literal}
import org.apache.spark.sql.execution.datasources.DataSourceAnalysis
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.types.{DataType, IntegerType, StructType}
class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll {
@@ -49,7 +49,11 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll {
}
Seq(true, false).foreach { caseSensitive =>
- val rule = DataSourceAnalysis(new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive))
+ val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive)
+ def cast(e: Expression, dt: DataType): Expression = {
+ Cast(e, dt, Option(conf.sessionLocalTimeZone))
+ }
+ val rule = DataSourceAnalysis(conf)
test(
s"convertStaticPartitions only handle INSERT having at least static partitions " +
s"(caseSensitive: $caseSensitive)") {
@@ -150,7 +154,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll {
if (!caseSensitive) {
val nonPartitionedAttributes = Seq('e.int, 'f.int)
val expected = nonPartitionedAttributes ++
- Seq(Cast(Literal("1"), IntegerType), Cast(Literal("3"), IntegerType))
+ Seq(cast(Literal("1"), IntegerType), cast(Literal("3"), IntegerType))
val actual = rule.convertStaticPartitions(
sourceAttributes = nonPartitionedAttributes,
providedPartitions = Map("b" -> Some("1"), "C" -> Some("3")),
@@ -162,7 +166,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll {
{
val nonPartitionedAttributes = Seq('e.int, 'f.int)
val expected = nonPartitionedAttributes ++
- Seq(Cast(Literal("1"), IntegerType), Cast(Literal("3"), IntegerType))
+ Seq(cast(Literal("1"), IntegerType), cast(Literal("3"), IntegerType))
val actual = rule.convertStaticPartitions(
sourceAttributes = nonPartitionedAttributes,
providedPartitions = Map("b" -> Some("1"), "c" -> Some("3")),
@@ -174,7 +178,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll {
// Test the case having a single static partition column.
{
val nonPartitionedAttributes = Seq('e.int, 'f.int)
- val expected = nonPartitionedAttributes ++ Seq(Cast(Literal("1"), IntegerType))
+ val expected = nonPartitionedAttributes ++ Seq(cast(Literal("1"), IntegerType))
val actual = rule.convertStaticPartitions(
sourceAttributes = nonPartitionedAttributes,
providedPartitions = Map("b" -> Some("1")),
@@ -189,7 +193,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll {
val dynamicPartitionAttributes = Seq('g.int)
val expected =
nonPartitionedAttributes ++
- Seq(Cast(Literal("1"), IntegerType)) ++
+ Seq(cast(Literal("1"), IntegerType)) ++
dynamicPartitionAttributes
val actual = rule.convertStaticPartitions(
sourceAttributes = nonPartitionedAttributes ++ dynamicPartitionAttributes,
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
index 9d3b31f39c..e16c9e46b7 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
@@ -101,7 +101,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session
experimentalMethods.extraStrategies ++
extraPlanningStrategies ++ Seq(
FileSourceStrategy,
- DataSourceStrategy,
+ DataSourceStrategy(conf),
SpecialLimits,
InMemoryScans,
HiveTableScans,