From 0cb7662d8683c913c4fff02e8fb0ec75261d9731 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 28 Oct 2015 21:35:57 -0700 Subject: [SPARK-11351] [SQL] support hive interval literal Author: Wenchen Fan Closes #9304 from cloud-fan/interval. --- .../org/apache/spark/sql/catalyst/SqlParser.scala | 71 ++++++++++++++++------ .../apache/spark/sql/catalyst/SqlParserSuite.scala | 52 ++++++++++++++++ 2 files changed, 103 insertions(+), 20 deletions(-) (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 833368b7d5..0fef043027 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -322,7 +322,7 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val literal: Parser[Literal] = ( numericLiteral | booleanLiteral - | stringLit ^^ {case s => Literal.create(s, StringType) } + | stringLit ^^ { case s => Literal.create(s, StringType) } | intervalLiteral | NULL ^^^ Literal.create(null, NullType) ) @@ -349,13 +349,12 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val integral: Parser[String] = sign.? ~ numericLit ^^ { case s ~ n => s.getOrElse("") + n } - private def intervalUnit(unitName: String) = - acceptIf { - case lexical.Identifier(str) => - val normalized = lexical.normalizeKeyword(str) - normalized == unitName || normalized == unitName + "s" - case _ => false - } {_ => "wrong interval unit"} + private def intervalUnit(unitName: String) = acceptIf { + case lexical.Identifier(str) => + val normalized = lexical.normalizeKeyword(str) + normalized == unitName || normalized == unitName + "s" + case _ => false + } {_ => "wrong interval unit"} protected lazy val month: Parser[Int] = integral <~ intervalUnit("month") ^^ { case num => num.toInt } @@ -396,21 +395,53 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { case num => num.toLong * CalendarInterval.MICROS_PER_WEEK } + private def intervalKeyword(keyword: String) = acceptIf { + case lexical.Identifier(str) => + lexical.normalizeKeyword(str) == keyword + case _ => false + } {_ => "wrong interval keyword"} + protected lazy val intervalLiteral: Parser[Literal] = - INTERVAL ~> year.? ~ month.? ~ week.? ~ day.? ~ hour.? ~ minute.? ~ second.? ~ - millisecond.? ~ microsecond.? ^^ { - case year ~ month ~ week ~ day ~ hour ~ minute ~ second ~ + ( INTERVAL ~> stringLit <~ intervalKeyword("year") ~ intervalKeyword("to") ~ + intervalKeyword("month") ^^ { case s => + Literal(CalendarInterval.fromYearMonthString(s)) + } + | INTERVAL ~> stringLit <~ intervalKeyword("day") ~ intervalKeyword("to") ~ + intervalKeyword("second") ^^ { case s => + Literal(CalendarInterval.fromDayTimeString(s)) + } + | INTERVAL ~> stringLit <~ intervalKeyword("year") ^^ { case s => + Literal(CalendarInterval.fromSingleUnitString("year", s)) + } + | INTERVAL ~> stringLit <~ intervalKeyword("month") ^^ { case s => + Literal(CalendarInterval.fromSingleUnitString("month", s)) + } + | INTERVAL ~> stringLit <~ intervalKeyword("day") ^^ { case s => + Literal(CalendarInterval.fromSingleUnitString("day", s)) + } + | INTERVAL ~> stringLit <~ intervalKeyword("hour") ^^ { case s => + Literal(CalendarInterval.fromSingleUnitString("hour", s)) + } + | INTERVAL ~> stringLit <~ intervalKeyword("minute") ^^ { case s => + Literal(CalendarInterval.fromSingleUnitString("minute", s)) + } + | INTERVAL ~> stringLit <~ intervalKeyword("second") ^^ { case s => + Literal(CalendarInterval.fromSingleUnitString("second", s)) + } + | INTERVAL ~> year.? ~ month.? ~ week.? ~ day.? ~ hour.? ~ minute.? ~ second.? ~ + millisecond.? ~ microsecond.? ^^ { case year ~ month ~ week ~ day ~ hour ~ minute ~ second ~ millisecond ~ microsecond => - if (!Seq(year, month, week, day, hour, minute, second, - millisecond, microsecond).exists(_.isDefined)) { - throw new AnalysisException( - "at least one time unit should be given for interval literal") - } - val months = Seq(year, month).map(_.getOrElse(0)).sum - val microseconds = Seq(week, day, hour, minute, second, millisecond, microsecond) - .map(_.getOrElse(0L)).sum - Literal.create(new CalendarInterval(months, microseconds), CalendarIntervalType) + if (!Seq(year, month, week, day, hour, minute, second, + millisecond, microsecond).exists(_.isDefined)) { + throw new AnalysisException( + "at least one time unit should be given for interval literal") } + val months = Seq(year, month).map(_.getOrElse(0)).sum + val microseconds = Seq(week, day, hour, minute, second, millisecond, microsecond) + .map(_.getOrElse(0L)).sum + Literal(new CalendarInterval(months, microseconds)) + } + ) private def toNarrowestIntegerType(value: String): Any = { val bigIntValue = BigDecimal(value) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala index 79b4846cb9..ea28bfa021 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias import org.apache.spark.sql.catalyst.expressions.{Literal, GreaterThan, Not, Attribute} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project, LogicalPlan, Command} +import org.apache.spark.unsafe.types.CalendarInterval private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Command { override def output: Seq[Attribute] = Seq.empty @@ -74,4 +75,55 @@ class SqlParserSuite extends PlanTest { OneRowRelation) comparePlans(parsed, expected) } + + test("support hive interval literal") { + def checkInterval(sql: String, result: CalendarInterval): Unit = { + val parsed = SqlParser.parse(sql) + val expected = Project( + UnresolvedAlias( + Literal(result) + ) :: Nil, + OneRowRelation) + comparePlans(parsed, expected) + } + + def checkYearMonth(lit: String): Unit = { + checkInterval( + s"SELECT INTERVAL '$lit' YEAR TO MONTH", + CalendarInterval.fromYearMonthString(lit)) + } + + def checkDayTime(lit: String): Unit = { + checkInterval( + s"SELECT INTERVAL '$lit' DAY TO SECOND", + CalendarInterval.fromDayTimeString(lit)) + } + + def checkSingleUnit(lit: String, unit: String): Unit = { + checkInterval( + s"SELECT INTERVAL '$lit' $unit", + CalendarInterval.fromSingleUnitString(unit, lit)) + } + + checkYearMonth("123-10") + checkYearMonth("496-0") + checkYearMonth("-2-3") + checkYearMonth("-123-0") + + checkDayTime("99 11:22:33.123456789") + checkDayTime("-99 11:22:33.123456789") + checkDayTime("10 9:8:7.123456789") + checkDayTime("1 0:0:0") + checkDayTime("-1 0:0:0") + checkDayTime("1 0:0:1") + + for (unit <- Seq("year", "month", "day", "hour", "minute", "second")) { + checkSingleUnit("7", unit) + checkSingleUnit("-7", unit) + checkSingleUnit("0", unit) + } + + checkSingleUnit("13.123456789", "second") + checkSingleUnit("-13.123456789", "second") + } } -- cgit v1.2.3