aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <viirya@gmail.com>2017-03-03 07:14:37 -0800
committerHerman van Hovell <hvanhovell@databricks.com>2017-03-03 07:14:37 -0800
commit98bcc188f98e44c1675d8b3a28f44f4f900abc43 (patch)
treef7340f325fd31749c542ac32974adcd801cfccca /sql/catalyst
parent776fac3988271a1e4128cb31f21e5f7f3b7bcf0e (diff)
downloadspark-98bcc188f98e44c1675d8b3a28f44f4f900abc43.tar.gz
spark-98bcc188f98e44c1675d8b3a28f44f4f900abc43.tar.bz2
spark-98bcc188f98e44c1675d8b3a28f44f4f900abc43.zip
[SPARK-19758][SQL] Resolving timezone aware expressions with time zone when resolving inline table
## What changes were proposed in this pull request? When we resolve inline tables in the analyzer, we will evaluate the expressions of inline tables. When it evaluates a `TimeZoneAwareExpression`, an error will happen because the `TimeZoneAwareExpression` is not yet associated with a timezone. So we need to resolve these `TimeZoneAwareExpression`s with the time zone when resolving inline tables. ## How was this patch tested? Jenkins tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #17114 from viirya/resolve-timeawareexpr-inline-table.
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala16
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala40
3 files changed, 36 insertions, 22 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index c477cb48d0..6d569b612d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -146,7 +146,7 @@ class Analyzer(
GlobalAggregates ::
ResolveAggregateFunctions ::
TimeWindowing ::
- ResolveInlineTables ::
+ ResolveInlineTables(conf) ::
TypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
index 7323197b10..d5b3ea8c37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.analysis
import scala.util.control.NonFatal
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Cast
+import org.apache.spark.sql.catalyst.{CatalystConf, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.{Cast, TimeZoneAwareExpression}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.{StructField, StructType}
@@ -28,7 +28,7 @@ import org.apache.spark.sql.types.{StructField, StructType}
/**
* An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]].
*/
-object ResolveInlineTables extends Rule[LogicalPlan] {
+case class ResolveInlineTables(conf: CatalystConf) extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case table: UnresolvedInlineTable if table.expressionsResolved =>
validateInputDimension(table)
@@ -95,11 +95,15 @@ object ResolveInlineTables extends Rule[LogicalPlan] {
InternalRow.fromSeq(row.zipWithIndex.map { case (e, ci) =>
val targetType = fields(ci).dataType
try {
- if (e.dataType.sameType(targetType)) {
- e.eval()
+ val castedExpr = if (e.dataType.sameType(targetType)) {
+ e
} else {
- Cast(e, targetType).eval()
+ Cast(e, targetType)
}
+ castedExpr.transform {
+ case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
+ e.withTimeZone(conf.sessionLocalTimeZone)
+ }.eval()
} catch {
case NonFatal(ex) =>
table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala
index 920c6ea50f..f45a826869 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala
@@ -20,68 +20,67 @@ package org.apache.spark.sql.catalyst.analysis
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.{Literal, Rand}
+import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, Rand}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
-import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.types.{LongType, NullType}
+import org.apache.spark.sql.types.{LongType, NullType, TimestampType}
/**
* Unit tests for [[ResolveInlineTables]]. Note that there are also test cases defined in
* end-to-end tests (in sql/core module) for verifying the correct error messages are shown
* in negative cases.
*/
-class ResolveInlineTablesSuite extends PlanTest with BeforeAndAfter {
+class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter {
private def lit(v: Any): Literal = Literal(v)
test("validate inputs are foldable") {
- ResolveInlineTables.validateInputEvaluable(
+ ResolveInlineTables(conf).validateInputEvaluable(
UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)))))
// nondeterministic (rand) should not work
intercept[AnalysisException] {
- ResolveInlineTables.validateInputEvaluable(
+ ResolveInlineTables(conf).validateInputEvaluable(
UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1)))))
}
// aggregate should not work
intercept[AnalysisException] {
- ResolveInlineTables.validateInputEvaluable(
+ ResolveInlineTables(conf).validateInputEvaluable(
UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1))))))
}
// unresolved attribute should not work
intercept[AnalysisException] {
- ResolveInlineTables.validateInputEvaluable(
+ ResolveInlineTables(conf).validateInputEvaluable(
UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A")))))
}
}
test("validate input dimensions") {
- ResolveInlineTables.validateInputDimension(
+ ResolveInlineTables(conf).validateInputDimension(
UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2)))))
// num alias != data dimension
intercept[AnalysisException] {
- ResolveInlineTables.validateInputDimension(
+ ResolveInlineTables(conf).validateInputDimension(
UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2)))))
}
// num alias == data dimension, but data themselves are inconsistent
intercept[AnalysisException] {
- ResolveInlineTables.validateInputDimension(
+ ResolveInlineTables(conf).validateInputDimension(
UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22)))))
}
}
test("do not fire the rule if not all expressions are resolved") {
val table = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A"))))
- assert(ResolveInlineTables(table) == table)
+ assert(ResolveInlineTables(conf)(table) == table)
}
test("convert") {
val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
- val converted = ResolveInlineTables.convert(table)
+ val converted = ResolveInlineTables(conf).convert(table)
assert(converted.output.map(_.dataType) == Seq(LongType))
assert(converted.data.size == 2)
@@ -89,13 +88,24 @@ class ResolveInlineTablesSuite extends PlanTest with BeforeAndAfter {
assert(converted.data(1).getLong(0) == 2L)
}
+ test("convert TimeZoneAwareExpression") {
+ val table = UnresolvedInlineTable(Seq("c1"),
+ Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType))))
+ val converted = ResolveInlineTables(conf).convert(table)
+ val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType)
+ .withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long]
+ assert(converted.output.map(_.dataType) == Seq(TimestampType))
+ assert(converted.data.size == 1)
+ assert(converted.data(0).getLong(0) == correct)
+ }
+
test("nullability inference in convert") {
val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
- val converted1 = ResolveInlineTables.convert(table1)
+ val converted1 = ResolveInlineTables(conf).convert(table1)
assert(!converted1.schema.fields(0).nullable)
val table2 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType))))
- val converted2 = ResolveInlineTables.convert(table2)
+ val converted2 = ResolveInlineTables(conf).convert(table2)
assert(converted2.schema.fields(0).nullable)
}
}