diff options
3 files changed, 8 insertions, 4 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 168a4e30ea..fe0d3f2997 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -251,10 +251,10 @@ trait HiveTypeCoercion { p.makeCopy(Array(Cast(p.left, StringType), p.right)) case p: BinaryComparison if p.left.dataType == StringType && p.right.dataType == TimestampType => - p.makeCopy(Array(p.left, Cast(p.right, StringType))) + p.makeCopy(Array(Cast(p.left, TimestampType), p.right)) case p: BinaryComparison if p.left.dataType == TimestampType && p.right.dataType == StringType => - p.makeCopy(Array(Cast(p.left, StringType), p.right)) + p.makeCopy(Array(p.left, Cast(p.right, TimestampType))) case p: BinaryComparison if p.left.dataType == TimestampType && p.right.dataType == DateType => p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType))) @@ -274,7 +274,7 @@ trait HiveTypeCoercion { i.makeCopy(Array(Cast(a, StringType), b)) case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) => - i.makeCopy(Array(Cast(a, StringType), b)) + i.makeCopy(Array(a, b.map(Cast(_, TimestampType)))) case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == TimestampType) => i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 8cdbe076cb..479ad9fe62 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -298,6 +298,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-3173 Timestamp support in the parser") { checkAnswer(sql( + "SELECT time FROM timestamps WHERE time='1969-12-31 16:00:00.0'"), + Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00"))) + + checkAnswer(sql( "SELECT time FROM timestamps WHERE time=CAST('1969-12-31 16:00:00.001' AS TIMESTAMP)"), Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001"))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 446771ab2a..8fbc2d23d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -175,7 +175,7 @@ object TestData { "4, D4, true, 2147483644" :: Nil) case class TimestampField(time: Timestamp) - val timestamps = TestSQLContext.sparkContext.parallelize((1 to 3).map { i => + val timestamps = TestSQLContext.sparkContext.parallelize((0 to 3).map { i => TimestampField(new Timestamp(i)) }) timestamps.toDF().registerTempTable("timestamps") |