diff options
5 files changed, 309 insertions, 0 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 7d7b4b9167..c3f2935010 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -45,6 +45,7 @@ import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DropTable, AnalyzeTable, HiveScriptIOSchema} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.random.RandomSampler /* Implicit conversions */ @@ -1519,6 +1520,30 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case ast: ASTNode if ast.getType == HiveParser.TOK_CHARSETLITERAL => Literal(BaseSemanticAnalyzer.charSetString(ast.getChild(0).getText, ast.getChild(1).getText)) + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_YEAR_MONTH_LITERAL => + Literal(CalendarInterval.fromYearMonthString(ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_DAY_TIME_LITERAL => + Literal(CalendarInterval.fromDayTimeString(ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_YEAR_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("year", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_MONTH_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("month", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_DAY_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("day", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_HOUR_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("hour", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_MINUTE_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("minute", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_SECOND_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("second", ast.getText)) + case a: ASTNode => throw new NotImplementedError( s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText} : diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala index f765395e14..79cf40aba4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala @@ -175,4 +175,19 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { assert(desc.serde == Option("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe")) assert(desc.properties == Map(("tbl_p1" -> "p11"), ("tbl_p2" -> "p22"))) } + + test("Invalid interval term should throw AnalysisException") { + def assertError(sql: String, errorMessage: String): Unit = { + val e = intercept[AnalysisException] { + HiveQl.parseSql(sql) + } + assert(e.getMessage.contains(errorMessage)) + } + assertError("select interval '42-32' year to month", + "month 32 outside range [0, 11]") + assertError("select interval '5 49:12:15' day to second", + "hour 49 outside range [0, 23]") + assertError("select interval '.1111111111' second", + "nanosecond 1111111111 outside range") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 1dff07a6de..2fa7ae3fa2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval case class Nested1(f1: Nested2) case class Nested2(f2: Nested3) @@ -1115,4 +1116,25 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) } + + test("Convert hive interval term into Literal of CalendarIntervalType") { + checkAnswer(sql("select interval '10-9' year to month"), + Row(CalendarInterval.fromString("interval 10 years 9 months"))) + checkAnswer(sql("select interval '20 15:40:32.99899999' day to second"), + Row(CalendarInterval.fromString("interval 2 weeks 6 days 15 hours 40 minutes " + + "32 seconds 99 milliseconds 899 microseconds"))) + checkAnswer(sql("select interval '30' year"), + Row(CalendarInterval.fromString("interval 30 years"))) + checkAnswer(sql("select interval '25' month"), + Row(CalendarInterval.fromString("interval 25 months"))) + checkAnswer(sql("select interval '-100' day"), + Row(CalendarInterval.fromString("interval -14 weeks -2 days"))) + checkAnswer(sql("select interval '40' hour"), + Row(CalendarInterval.fromString("interval 1 days 16 hours"))) + checkAnswer(sql("select interval '80' minute"), + Row(CalendarInterval.fromString("interval 1 hour 20 minutes"))) + checkAnswer(sql("select interval '299.889987299' second"), + Row(CalendarInterval.fromString( + "interval 4 minutes 59 seconds 889 milliseconds 987 microseconds"))) + } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java index 92a5e4f86f..30e1758076 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java @@ -50,6 +50,14 @@ public final class CalendarInterval implements Serializable { unitRegex("week") + unitRegex("day") + unitRegex("hour") + unitRegex("minute") + unitRegex("second") + unitRegex("millisecond") + unitRegex("microsecond")); + private static Pattern yearMonthPattern = + Pattern.compile("^(?:['|\"])?([+|-])?(\\d+)-(\\d+)(?:['|\"])?$"); + + private static Pattern dayTimePattern = + Pattern.compile("^(?:['|\"])?([+|-])?(\\d+) (\\d+):(\\d+):(\\d+)(\\.(\\d+))?(?:['|\"])?$"); + + private static Pattern quoteTrimPattern = Pattern.compile("^(?:['|\"])?(.*?)(?:['|\"])?$"); + private static long toLong(String s) { if (s == null) { return 0; @@ -79,6 +87,154 @@ public final class CalendarInterval implements Serializable { } } + public static long toLongWithRange(String fieldName, + String s, long minValue, long maxValue) throws IllegalArgumentException { + long result = 0; + if (s != null) { + result = Long.valueOf(s); + if (result < minValue || result > maxValue) { + throw new IllegalArgumentException(String.format("%s %d outside range [%d, %d]", + fieldName, result, minValue, maxValue)); + } + } + return result; + } + + /** + * Parse YearMonth string in form: [-]YYYY-MM + * + * adapted from HiveIntervalYearMonth.valueOf + */ + public static CalendarInterval fromYearMonthString(String s) throws IllegalArgumentException { + CalendarInterval result = null; + if (s == null) { + throw new IllegalArgumentException("Interval year-month string was null"); + } + s = s.trim(); + Matcher m = yearMonthPattern.matcher(s); + if (!m.matches()) { + throw new IllegalArgumentException( + "Interval string does not match year-month format of 'y-m': " + s); + } else { + try { + int sign = m.group(1) != null && m.group(1).equals("-") ? -1 : 1; + int years = (int) toLongWithRange("year", m.group(2), 0, Integer.MAX_VALUE); + int months = (int) toLongWithRange("month", m.group(3), 0, 11); + result = new CalendarInterval(sign * (years * 12 + months), 0); + } catch (Exception e) { + throw new IllegalArgumentException( + "Error parsing interval year-month string: " + e.getMessage(), e); + } + } + return result; + } + + /** + * Parse dayTime string in form: [-]d HH:mm:ss.nnnnnnnnn + * + * adapted from HiveIntervalDayTime.valueOf + */ + public static CalendarInterval fromDayTimeString(String s) throws IllegalArgumentException { + CalendarInterval result = null; + if (s == null) { + throw new IllegalArgumentException("Interval day-time string was null"); + } + s = s.trim(); + Matcher m = dayTimePattern.matcher(s); + if (!m.matches()) { + throw new IllegalArgumentException( + "Interval string does not match day-time format of 'd h:m:s.n': " + s); + } else { + try { + int sign = m.group(1) != null && m.group(1).equals("-") ? -1 : 1; + long days = toLongWithRange("day", m.group(2), 0, Integer.MAX_VALUE); + long hours = toLongWithRange("hour", m.group(3), 0, 23); + long minutes = toLongWithRange("minute", m.group(4), 0, 59); + long seconds = toLongWithRange("second", m.group(5), 0, 59); + // Hive allow nanosecond precision interval + long nanos = toLongWithRange("nanosecond", m.group(7), 0L, 999999999L); + result = new CalendarInterval(0, sign * ( + days * MICROS_PER_DAY + hours * MICROS_PER_HOUR + minutes * MICROS_PER_MINUTE + + seconds * MICROS_PER_SECOND + nanos / 1000L)); + } catch (Exception e) { + throw new IllegalArgumentException( + "Error parsing interval day-time string: " + e.getMessage(), e); + } + } + return result; + } + + public static CalendarInterval fromSingleUnitString(String unit, String s) + throws IllegalArgumentException { + + CalendarInterval result = null; + if (s == null) { + throw new IllegalArgumentException(String.format("Interval %s string was null", unit)); + } + s = s.trim(); + Matcher m = quoteTrimPattern.matcher(s); + if (!m.matches()) { + throw new IllegalArgumentException( + "Interval string does not match day-time format of 'd h:m:s.n': " + s); + } else { + try { + if (unit.equals("year")) { + int year = (int) toLongWithRange("year", m.group(1), + Integer.MIN_VALUE / 12, Integer.MAX_VALUE / 12); + result = new CalendarInterval(year * 12, 0L); + + } else if (unit.equals("month")) { + int month = (int) toLongWithRange("month", m.group(1), + Integer.MIN_VALUE, Integer.MAX_VALUE); + result = new CalendarInterval(month, 0L); + + } else if (unit.equals("day")) { + long day = toLongWithRange("day", m.group(1), + Long.MIN_VALUE / MICROS_PER_DAY, Long.MAX_VALUE / MICROS_PER_DAY); + result = new CalendarInterval(0, day * MICROS_PER_DAY); + + } else if (unit.equals("hour")) { + long hour = toLongWithRange("hour", m.group(1), + Long.MIN_VALUE / MICROS_PER_HOUR, Long.MAX_VALUE / MICROS_PER_HOUR); + result = new CalendarInterval(0, hour * MICROS_PER_HOUR); + + } else if (unit.equals("minute")) { + long minute = toLongWithRange("minute", m.group(1), + Long.MIN_VALUE / MICROS_PER_MINUTE, Long.MAX_VALUE / MICROS_PER_MINUTE); + result = new CalendarInterval(0, minute * MICROS_PER_MINUTE); + + } else if (unit.equals("second")) { + long micros = parseSecondNano(m.group(1)); + result = new CalendarInterval(0, micros); + } + } catch (Exception e) { + throw new IllegalArgumentException("Error parsing interval string: " + e.getMessage(), e); + } + } + return result; + } + + /** + * Parse second_nano string in ss.nnnnnnnnn format to microseconds + */ + public static long parseSecondNano(String secondNano) throws IllegalArgumentException { + String[] parts = secondNano.split("\\."); + if (parts.length == 1) { + return toLongWithRange("second", parts[0], Long.MIN_VALUE / MICROS_PER_SECOND, + Long.MAX_VALUE / MICROS_PER_SECOND) * MICROS_PER_SECOND; + + } else if (parts.length == 2) { + long seconds = parts[0].equals("") ? 0L : toLongWithRange("second", parts[0], + Long.MIN_VALUE / MICROS_PER_SECOND, Long.MAX_VALUE / MICROS_PER_SECOND); + long nanos = toLongWithRange("nanosecond", parts[1], 0L, 999999999L); + return seconds * MICROS_PER_SECOND + nanos / 1000L; + + } else { + throw new IllegalArgumentException( + "Interval string does not match second-nano format of ss.nnnnnnnnn"); + } + } + public final int months; public final long microseconds; diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java index 6274b92b47..80d4982c4b 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java @@ -102,6 +102,97 @@ public class CalendarIntervalSuite { } @Test + public void fromYearMonthStringTest() { + String input; + CalendarInterval i; + + input = "99-10"; + i = new CalendarInterval(99 * 12 + 10, 0L); + assertEquals(CalendarInterval.fromYearMonthString(input), i); + + input = "-8-10"; + i = new CalendarInterval(-8 * 12 - 10, 0L); + assertEquals(CalendarInterval.fromYearMonthString(input), i); + + try { + input = "99-15"; + CalendarInterval.fromYearMonthString(input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("month 15 outside range")); + } + } + + @Test + public void fromDayTimeStringTest() { + String input; + CalendarInterval i; + + input = "5 12:40:30.999999999"; + i = new CalendarInterval(0, 5 * MICROS_PER_DAY + 12 * MICROS_PER_HOUR + + 40 * MICROS_PER_MINUTE + 30 * MICROS_PER_SECOND + 999999L); + assertEquals(CalendarInterval.fromDayTimeString(input), i); + + input = "10 0:12:0.888"; + i = new CalendarInterval(0, 10 * MICROS_PER_DAY + 12 * MICROS_PER_MINUTE); + assertEquals(CalendarInterval.fromDayTimeString(input), i); + + input = "-3 0:0:0"; + i = new CalendarInterval(0, -3 * MICROS_PER_DAY); + assertEquals(CalendarInterval.fromDayTimeString(input), i); + + try { + input = "5 30:12:20"; + CalendarInterval.fromDayTimeString(input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("hour 30 outside range")); + } + + try { + input = "5 30-12"; + CalendarInterval.fromDayTimeString(input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("not match day-time format")); + } + } + + @Test + public void fromSingleUnitStringTest() { + String input; + CalendarInterval i; + + input = "12"; + i = new CalendarInterval(12 * 12, 0L); + assertEquals(CalendarInterval.fromSingleUnitString("year", input), i); + + input = "100"; + i = new CalendarInterval(0, 100 * MICROS_PER_DAY); + assertEquals(CalendarInterval.fromSingleUnitString("day", input), i); + + input = "1999.38888"; + i = new CalendarInterval(0, 1999 * MICROS_PER_SECOND + 38); + assertEquals(CalendarInterval.fromSingleUnitString("second", input), i); + + try { + input = String.valueOf(Integer.MAX_VALUE); + CalendarInterval.fromSingleUnitString("year", input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("outside range")); + } + + try { + input = String.valueOf(Long.MAX_VALUE / MICROS_PER_HOUR + 1); + CalendarInterval.fromSingleUnitString("hour", input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("outside range")); + } + } + + @Test public void addTest() { String input = "interval 3 month 1 hour"; String input2 = "interval 2 month 100 hour"; |