author    Michael Armbrust <michael@databricks.com>  2014-04-24 18:21:00 -0700
committer Reynold Xin <rxin@apache.org>  2014-04-24 18:21:00 -0700
commit    4660991e679eda158a3ae8039b686eae197a71d1 (patch)
tree      41cbb59c942eb818bbf427daf10342b81944e97a /sql
parent    526a518bf32ad55b926a26f16086f445fd0ae29f (diff)
[SQL] Add support for parsing indexing into arrays in SQL.
Author: Michael Armbrust <michael@databricks.com>

Closes #518 from marmbrus/parseArrayIndex and squashes the following commits:

afd2d6b [Michael Armbrust] 100 chars
c3d6026 [Michael Armbrust] Add support for parsing indexing into arrays in SQL.
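In practice, the new grammar rule lets SQL queries index array columns with square brackets, including chained indexing into nested arrays and computed indices. A minimal sketch of the queries this enables, drawn from the tests below (the arrayData table is registered in TestData):

  sql("SELECT data[0] FROM arrayData")           // first element of an array column
  sql("SELECT nestedData[0][1] FROM arrayData")  // chained indexing into nested arrays
  sql("SELECT data[0 + 1] FROM arrayData")       // the ordinal can itself be an expression

Each bracketed access parses to a GetItem(base, ordinal) expression in the Catalyst tree, as the SqlParser change below shows.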
Diffstat (limited to 'sql')
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala | 10 +++++++---
 sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala          | 16 ++++++++++++++++
 sql/core/src/test/scala/org/apache/spark/sql/TestData.scala               |  7 +++++++
 3 files changed, 30 insertions(+), 3 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 13a19d0adf..8c76a3aa96 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst
import scala.language.implicitConversions
import scala.util.parsing.combinator.lexical.StdLexical
import scala.util.parsing.combinator.syntactical.StandardTokenParsers
+import scala.util.parsing.combinator.PackratParsers
import scala.util.parsing.input.CharArrayReader.EofCh
import org.apache.spark.sql.catalyst.analysis._
@@ -39,7 +40,7 @@ import org.apache.spark.sql.catalyst.types._
* This is currently included mostly for illustrative purposes. Users wanting more complete support
* for a SQL like language should checkout the HiveQL support in the sql/hive sub-project.
*/
-class SqlParser extends StandardTokenParsers {
+class SqlParser extends StandardTokenParsers with PackratParsers {
def apply(input: String): LogicalPlan = {
phrase(query)(new lexical.Scanner(input)) match {
case Success(r, x) => r
@@ -152,7 +153,7 @@ class SqlParser extends StandardTokenParsers {
lexical.delimiters += (
"@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")",
- ",", ";", "%", "{", "}", ":"
+ ",", ";", "%", "{", "}", ":", "[", "]"
)
protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = {
@@ -339,7 +340,10 @@ class SqlParser extends StandardTokenParsers {
protected lazy val floatLit: Parser[String] =
elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars)
- protected lazy val baseExpression: Parser[Expression] =
+ protected lazy val baseExpression: PackratParser[Expression] =
+ expression ~ "[" ~ expression <~ "]" ^^ {
+ case base ~ _ ~ ordinal => GetItem(base, ordinal)
+ } |
TRUE ^^^ Literal(true, BooleanType) |
FALSE ^^^ Literal(false, BooleanType) |
cast |
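A note on the PackratParsers mixin above: the new baseExpression alternative is left-recursive, since expression eventually reaches baseExpression and the rule begins by consuming an expression again. A plain StandardTokenParsers recursive-descent parser would recurse forever on such a rule; packrat parsing memoizes intermediate results and supports left recursion, which is also what lets chained accesses like nestedData[0][1] nest naturally. A self-contained sketch of the same pattern, simplified and not Spark code:

  import scala.util.parsing.combinator.PackratParsers
  import scala.util.parsing.combinator.syntactical.StandardTokenParsers

  // Simplified stand-in for SqlParser: numericLit replaces the full
  // expression grammar, and strings replace Catalyst expression nodes.
  object IndexingParser extends StandardTokenParsers with PackratParsers {
    lexical.delimiters ++= Seq("[", "]")

    // Left-recursive: the first alternative starts with expr itself.
    // Declaring it as a lazy PackratParser makes this terminate.
    lazy val expr: PackratParser[String] =
      expr ~ ("[" ~> expr <~ "]") ^^ { case base ~ ordinal => s"GetItem($base, $ordinal)" } |
      numericLit

    // PackratParsers' phrase wraps the scanner in a PackratReader for us.
    def parse(input: String) = phrase(expr)(new lexical.Scanner(input))
  }

  // IndexingParser.parse("1[2][3]") succeeds with GetItem(GetItem(1, 2), 3),
  // i.e. chained indexing associates to the left.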
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 4c4fd6dbbe..dde957d715 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -28,6 +28,22 @@ class SQLQuerySuite extends QueryTest {
// Make sure the tables are loaded.
TestData
+ test("index into array") {
+ checkAnswer(
+ sql("SELECT data, data[0], data[0] + data[1], data[0 + 1] FROM arrayData"),
+ arrayData.map(d => (d.data, d.data(0), d.data(0) + d.data(1), d.data(1))).collect().toSeq)
+ }
+
+ test("index into array of arrays") {
+ checkAnswer(
+ sql(
+ "SELECT nestedData, nestedData[0][0], nestedData[0][0] + nestedData[0][1] FROM arrayData"),
+ arrayData.map(d =>
+ (d.nestedData,
+ d.nestedData(0)(0),
+ d.nestedData(0)(0) + d.nestedData(0)(1))).collect().toSeq)
+ }
+
test("agg") {
checkAnswer(
sql("SELECT a, SUM(b) FROM testData2 GROUP BY a"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 271b1d9fca..002b7f0ada 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -66,4 +66,11 @@ object TestData {
LowerCaseData(3, "c") ::
LowerCaseData(4, "d") :: Nil)
lowerCaseData.registerAsTable("lowerCaseData")
+
+ case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]])
+ val arrayData =
+ TestSQLContext.sparkContext.parallelize(
+ ArrayData(Seq(1,2,3), Seq(Seq(1,2,3))) ::
+ ArrayData(Seq(2,3,4), Seq(Seq(2,3,4))) :: Nil)
+ arrayData.registerAsTable("arrayData")
}
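For completeness, a hypothetical end-to-end use of the new syntax against this test table, assuming TestData has been loaded and TestSQLContext's members are imported as in SQLQuerySuite above:

  import org.apache.spark.sql.test.TestSQLContext._

  // Index the flat and nested array columns registered above.
  sql("SELECT data[0], nestedData[0][1] FROM arrayData").collect()
  // Expected rows for the two ArrayData records: (1, 2) and (2, 3)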