From 4660991e679eda158a3ae8039b686eae197a71d1 Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Thu, 24 Apr 2014 18:21:00 -0700
Subject: [SQL] Add support for parsing indexing into arrays in SQL.

Author: Michael Armbrust

Closes #518 from marmbrus/parseArrayIndex and squashes the following commits:

afd2d6b [Michael Armbrust] 100 chars
c3d6026 [Michael Armbrust] Add support for parsing indexing into arrays in SQL.
---
 .../scala/org/apache/spark/sql/catalyst/SqlParser.scala | 10 +++++++---
 .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 16 ++++++++++++++++
 .../src/test/scala/org/apache/spark/sql/TestData.scala  |  7 +++++++
 3 files changed, 30 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 13a19d0adf..8c76a3aa96 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst
 import scala.language.implicitConversions
 import scala.util.parsing.combinator.lexical.StdLexical
 import scala.util.parsing.combinator.syntactical.StandardTokenParsers
+import scala.util.parsing.combinator.PackratParsers
 import scala.util.parsing.input.CharArrayReader.EofCh
 
 import org.apache.spark.sql.catalyst.analysis._
@@ -39,7 +40,7 @@ import org.apache.spark.sql.catalyst.types._
  * This is currently included mostly for illustrative purposes. Users wanting more complete support
  * for a SQL like language should checkout the HiveQL support in the sql/hive sub-project.
  */
-class SqlParser extends StandardTokenParsers {
+class SqlParser extends StandardTokenParsers with PackratParsers {
   def apply(input: String): LogicalPlan = {
     phrase(query)(new lexical.Scanner(input)) match {
       case Success(r, x) => r
@@ -152,7 +153,7 @@ class SqlParser extends StandardTokenParsers {
 
   lexical.delimiters += (
     "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")",
-    ",", ";", "%", "{", "}", ":"
+    ",", ";", "%", "{", "}", ":", "[", "]"
   )
 
   protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = {
@@ -339,7 +340,10 @@ class SqlParser extends StandardTokenParsers {
   protected lazy val floatLit: Parser[String] =
     elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars)
 
-  protected lazy val baseExpression: Parser[Expression] =
+  protected lazy val baseExpression: PackratParser[Expression] =
+    expression ~ "[" ~ expression <~ "]" ^^ {
+      case base ~ _ ~ ordinal => GetItem(base, ordinal)
+    } |
     TRUE ^^^ Literal(true, BooleanType) |
     FALSE ^^^ Literal(false, BooleanType) |
     cast |
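Note on the parser change above: the new `expression ~ "[" ~ expression <~ "]"` alternative is left-recursive (`expression` eventually derives `baseExpression`), which would send a plain recursive-descent combinator parser into infinite recursion. Mixing in `PackratParsers` and typing the rule as `PackratParser[Expression]` enables memoized parsing with left-recursion support; `PackratParsers` also overrides `phrase` to wrap the input in a `PackratReader`, so the existing `apply` method works unchanged. A minimal, self-contained sketch of the same technique (illustrative only; names like `IndexParserDemo` are not from the patch):

    import scala.util.parsing.combinator.PackratParsers
    import scala.util.parsing.combinator.syntactical.StandardTokenParsers

    // Illustrative sketch only -- not part of the patch.
    object IndexParserDemo extends StandardTokenParsers with PackratParsers {
      lexical.delimiters += ("[", "]")

      sealed trait Expr
      case class Id(name: String) extends Expr
      case class Num(value: Int) extends Expr
      case class Index(base: Expr, ordinal: Expr) extends Expr

      // Left-recursive: the base of an index can itself be any expression.
      // As a plain Parser this alternative would recurse forever; as a
      // PackratParser, memoization handles the left recursion.
      lazy val expr: PackratParser[Expr] =
        expr ~ ("[" ~> expr <~ "]") ^^ { case base ~ ord => Index(base, ord) } |
        numericLit ^^ (n => Num(n.toInt)) |
        ident ^^ (name => Id(name))

      def main(args: Array[String]): Unit = {
        // phrase is overridden by PackratParsers to wrap the input in a
        // PackratReader, mirroring SqlParser.apply above. Chained brackets
        // associate to the left: Index(Index(Id(data),Num(0)),Num(1)).
        println(phrase(expr)(new lexical.Scanner("data[0][1]")))
      }
    }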
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 4c4fd6dbbe..dde957d715 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -28,6 +28,22 @@ class SQLQuerySuite extends QueryTest {
   // Make sure the tables are loaded.
   TestData
 
+  test("index into array") {
+    checkAnswer(
+      sql("SELECT data, data[0], data[0] + data[1], data[0 + 1] FROM arrayData"),
+      arrayData.map(d => (d.data, d.data(0), d.data(0) + d.data(1), d.data(1))).collect().toSeq)
+  }
+
+  test("index into array of arrays") {
+    checkAnswer(
+      sql(
+        "SELECT nestedData, nestedData[0][0], nestedData[0][0] + nestedData[0][1] FROM arrayData"),
+      arrayData.map(d =>
+        (d.nestedData,
+         d.nestedData(0)(0),
+         d.nestedData(0)(0) + d.nestedData(0)(1))).collect().toSeq)
+  }
+
   test("agg") {
     checkAnswer(
       sql("SELECT a, SUM(b) FROM testData2 GROUP BY a"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 271b1d9fca..002b7f0ada 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -66,4 +66,11 @@ object TestData {
     LowerCaseData(3, "c") ::
     LowerCaseData(4, "d") :: Nil)
   lowerCaseData.registerAsTable("lowerCaseData")
+
+  case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]])
+  val arrayData =
+    TestSQLContext.sparkContext.parallelize(
+      ArrayData(Seq(1,2,3), Seq(Seq(1,2,3))) ::
+      ArrayData(Seq(2,3,4), Seq(Seq(2,3,4))) :: Nil)
+  arrayData.registerAsTable("arrayData")
 }
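For context, a rough sketch of exercising the new bracket syntax against the arrayData table registered above, in the style of the test suite (assumes the Spark 1.0-era TestSQLContext that the tests themselves import; illustrative, not part of the patch):

    // Illustrative only -- assumes the Spark 1.0-era TestSQLContext used by
    // the test suite.
    import org.apache.spark.sql.test.TestSQLContext._

    // First element of each row's data array: 1 and 2 for the rows above.
    sql("SELECT data[0] FROM arrayData").collect()

    // The ordinal is an arbitrary expression, so data[0 + 1] is the second
    // element: 2 and 3.
    sql("SELECT data[0 + 1] FROM arrayData").collect()

    // Chained brackets index into nested arrays; the parser builds
    // GetItem(GetItem(nestedData, 0), 1), yielding 2 and 3.
    sql("SELECT nestedData[0][1] FROM arrayData").collect()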