From e4f4886d7148bf48f9e3462b83bfb1ecc7edbe31 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 10 Sep 2014 12:56:59 -0700 Subject: [SPARK-2096][SQL] Correctly parse dot notations First let me write down the current `projections` grammar of spark sql: expression : orExpression orExpression : andExpression {"or" andExpression} andExpression : comparisonExpression {"and" comparisonExpression} comparisonExpression : termExpression | termExpression "=" termExpression | termExpression ">" termExpression | ... termExpression : productExpression {"+"|"-" productExpression} productExpression : baseExpression {"*"|"/"|"%" baseExpression} baseExpression : expression "[" expression "]" | ... | ident | ... ident : identChar {identChar | digit} | delimiters | ... identChar : letter | "_" | "." delimiters : "," | ";" | "(" | ")" | "[" | "]" | ... projection : expression [["AS"] ident] projections : projection { "," projection} For something like `a.b.c[1]`, it will be parsed as: But for something like `a[1].b`, the current grammar can't parse it correctly. A simple solution is written in `ParquetQuerySuite#NestedSqlParser`, changed grammars are: delimiters : "." | "," | ";" | "(" | ")" | "[" | "]" | ... identChar : letter | "_" baseExpression : expression "[" expression "]" | expression "." ident | ... | ident | ... This works well, but can't cover some corner case like `select t.a.b from table as t`: `t.a.b` parsed as `GetField(GetField(UnResolved("t"), "a"), "b")` instead of `GetField(UnResolved("t.a"), "b")` using this new grammar. However, we can't resolve `t` as it's not a filed, but the whole table.(if we could do this, then `select t from table as t` is legal, which is unexpected) My solution is: dotExpressionHeader : ident "." ident baseExpression : expression "[" expression "]" | expression "." ident | ... | dotExpressionHeader | ident | ... I passed all test cases under sql locally and add a more complex case. "arrayOfStruct.field1 to access all values of field1" is not supported yet. Since this PR has changed a lot of code, I will open another PR for it. I'm not familiar with the latter optimize phase, please correct me if I missed something. Author: Wenchen Fan Author: Michael Armbrust Closes #2230 from cloud-fan/dot and squashes the following commits: e1a8898 [Wenchen Fan] remove support for arbitrary nested arrays ee8a724 [Wenchen Fan] rollback LogicalPlan, support dot operation on nested array type a58df40 [Michael Armbrust] add regression test for doubly nested data 16bc4c6 [Wenchen Fan] some enhance 95d733f [Wenchen Fan] split long line dc31698 [Wenchen Fan] SPARK-2096 Correctly parse dot notations --- .../org/apache/spark/sql/catalyst/SqlParser.scala | 13 ++- .../sql/catalyst/plans/logical/LogicalPlan.scala | 6 +- .../org/apache/spark/sql/json/JsonSuite.scala | 14 +++ .../org/apache/spark/sql/json/TestJsonData.scala | 26 ++++++ .../spark/sql/parquet/ParquetQuerySuite.scala | 102 +++++---------------- .../spark/sql/hive/execution/SQLQuerySuite.scala | 17 +++- 6 files changed, 88 insertions(+), 90 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index a04b4a938d..ca69531c69 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -357,16 +357,25 @@ class SqlParser extends StandardTokenParsers with PackratParsers { expression ~ "[" ~ expression <~ "]" ^^ { case base ~ _ ~ ordinal => GetItem(base, ordinal) } | + (expression <~ ".") ~ ident ^^ { + case base ~ fieldName => GetField(base, fieldName) + } | TRUE ^^^ Literal(true, BooleanType) | FALSE ^^^ Literal(false, BooleanType) | cast | "(" ~> expression <~ ")" | function | "-" ~> literal ^^ UnaryMinus | + dotExpressionHeader | ident ^^ UnresolvedAttribute | "*" ^^^ Star(None) | literal + protected lazy val dotExpressionHeader: Parser[Expression] = + (ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ { + case i1 ~ i2 ~ rest => UnresolvedAttribute(i1 + "." + i2 + rest.mkString(".", ".", "")) + } + protected lazy val dataType: Parser[DataType] = STRING ^^^ StringType | TIMESTAMP ^^^ TimestampType } @@ -380,7 +389,7 @@ class SqlLexical(val keywords: Seq[String]) extends StdLexical { delimiters += ( "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", - ",", ";", "%", "{", "}", ":", "[", "]" + ",", ";", "%", "{", "}", ":", "[", "]", "." ) override lazy val token: Parser[Token] = ( @@ -401,7 +410,7 @@ class SqlLexical(val keywords: Seq[String]) extends StdLexical { | failure("illegal character") ) - override def identChar = letter | elem('_') | elem('.') + override def identChar = letter | elem('_') override def whitespace: Parser[Any] = rep( whitespaceChar diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index f81d911194..bae491f07c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -104,11 +104,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { case Seq((a, Nil)) => Some(a) // One match, no nested fields, use it. // One match, but we also need to extract the requested nested field. case Seq((a, nestedFields)) => - a.dataType match { - case StructType(fields) => - Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)()) - case _ => None // Don't know how to resolve these field references - } + Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)()) case Seq() => None // No matches. case ambiguousReferences => throw new TreeNodeException( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 05513a1271..301d482d27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -581,4 +581,18 @@ class JsonSuite extends QueryTest { "this is a simple string.") :: Nil ) } + + test("SPARK-2096 Correctly parse dot notations") { + val jsonSchemaRDD = jsonRDD(complexFieldAndType2) + jsonSchemaRDD.registerTempTable("jsonTable") + + checkAnswer( + sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), + (true, "str1") :: Nil + ) + checkAnswer( + sql("select complexArrayOfStruct[0].field1[1].inner2[0], complexArrayOfStruct[1].field2[0][1] from jsonTable"), + ("str2", 6) :: Nil + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala index a88310b5f1..b3f95f08e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -82,4 +82,30 @@ object TestJsonData { """{"c":[33, 44]}""" :: """{"d":{"field":true}}""" :: """{"e":"str"}""" :: Nil) + + val complexFieldAndType2 = + TestSQLContext.sparkContext.parallelize( + """{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], + "complexArrayOfStruct": [ + { + "field1": [ + { + "inner1": "str1" + }, + { + "inner2": ["str2", "str22"] + }], + "field2": [[1, 2], [3, 4]] + }, + { + "field1": [ + { + "inner2": ["str3", "str33"] + }, + { + "inner1": "str4" + }], + "field2": [[5, 6], [7, 8]] + }] + }""" :: Nil) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 42923b6a28..b0a06cd3ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -17,19 +17,14 @@ package org.apache.spark.sql.parquet +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.mapreduce.Job import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} - import parquet.hadoop.ParquetFileWriter import parquet.hadoop.util.ContextUtil -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.mapreduce.Job - -import org.apache.spark.SparkContext import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{SqlLexical, SqlParser} -import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.types.{BooleanType, IntegerType} +import org.apache.spark.sql.catalyst.types.IntegerType import org.apache.spark.sql.catalyst.util.getTempFilePath import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ @@ -87,11 +82,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA var testRDD: SchemaRDD = null - // TODO: remove this once SqlParser can parse nested select statements - var nestedParserSqlContext: NestedParserSQLContext = null - override def beforeAll() { - nestedParserSqlContext = new NestedParserSQLContext(TestSQLContext.sparkContext) ParquetTestData.writeFile() ParquetTestData.writeFilterFile() ParquetTestData.writeNestedFile1() @@ -718,11 +709,9 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } test("Projection in addressbook") { - val data = nestedParserSqlContext - .parquetFile(ParquetTestData.testNestedDir1.toString) - .toSchemaRDD + val data = parquetFile(ParquetTestData.testNestedDir1.toString).toSchemaRDD data.registerTempTable("data") - val query = nestedParserSqlContext.sql("SELECT owner, contacts[1].name FROM data") + val query = sql("SELECT owner, contacts[1].name FROM data") val tmp = query.collect() assert(tmp.size === 2) assert(tmp(0).size === 2) @@ -733,21 +722,19 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } test("Simple query on nested int data") { - val data = nestedParserSqlContext - .parquetFile(ParquetTestData.testNestedDir2.toString) - .toSchemaRDD + val data = parquetFile(ParquetTestData.testNestedDir2.toString).toSchemaRDD data.registerTempTable("data") - val result1 = nestedParserSqlContext.sql("SELECT entries[0].value FROM data").collect() + val result1 = sql("SELECT entries[0].value FROM data").collect() assert(result1.size === 1) assert(result1(0).size === 1) assert(result1(0)(0) === 2.5) - val result2 = nestedParserSqlContext.sql("SELECT entries[0] FROM data").collect() + val result2 = sql("SELECT entries[0] FROM data").collect() assert(result2.size === 1) val subresult1 = result2(0)(0).asInstanceOf[CatalystConverter.StructScalaType[_]] assert(subresult1.size === 2) assert(subresult1(0) === 2.5) assert(subresult1(1) === false) - val result3 = nestedParserSqlContext.sql("SELECT outerouter FROM data").collect() + val result3 = sql("SELECT outerouter FROM data").collect() val subresult2 = result3(0)(0) .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) .asInstanceOf[CatalystConverter.ArrayScalaType[_]] @@ -760,19 +747,18 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } test("nested structs") { - val data = nestedParserSqlContext - .parquetFile(ParquetTestData.testNestedDir3.toString) + val data = parquetFile(ParquetTestData.testNestedDir3.toString) .toSchemaRDD data.registerTempTable("data") - val result1 = nestedParserSqlContext.sql("SELECT booleanNumberPairs[0].value[0].truth FROM data").collect() + val result1 = sql("SELECT booleanNumberPairs[0].value[0].truth FROM data").collect() assert(result1.size === 1) assert(result1(0).size === 1) assert(result1(0)(0) === false) - val result2 = nestedParserSqlContext.sql("SELECT booleanNumberPairs[0].value[1].truth FROM data").collect() + val result2 = sql("SELECT booleanNumberPairs[0].value[1].truth FROM data").collect() assert(result2.size === 1) assert(result2(0).size === 1) assert(result2(0)(0) === true) - val result3 = nestedParserSqlContext.sql("SELECT booleanNumberPairs[1].value[0].truth FROM data").collect() + val result3 = sql("SELECT booleanNumberPairs[1].value[0].truth FROM data").collect() assert(result3.size === 1) assert(result3(0).size === 1) assert(result3(0)(0) === false) @@ -796,11 +782,9 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } test("map with struct values") { - val data = nestedParserSqlContext - .parquetFile(ParquetTestData.testNestedDir4.toString) - .toSchemaRDD + val data = parquetFile(ParquetTestData.testNestedDir4.toString).toSchemaRDD data.registerTempTable("mapTable") - val result1 = nestedParserSqlContext.sql("SELECT data2 FROM mapTable").collect() + val result1 = sql("SELECT data2 FROM mapTable").collect() assert(result1.size === 1) val entry1 = result1(0)(0) .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] @@ -814,7 +798,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(entry2 != null) assert(entry2(0) === 49) assert(entry2(1) === null) - val result2 = nestedParserSqlContext.sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM mapTable""").collect() + val result2 = sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM mapTable""").collect() assert(result2.size === 1) assert(result2(0)(0) === 42.toLong) assert(result2(0)(1) === "the answer") @@ -825,15 +809,12 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA // has no effect in this test case val tmpdir = Utils.createTempDir() Utils.deleteRecursively(tmpdir) - val result = nestedParserSqlContext - .parquetFile(ParquetTestData.testNestedDir1.toString) - .toSchemaRDD + val result = parquetFile(ParquetTestData.testNestedDir1.toString).toSchemaRDD result.saveAsParquetFile(tmpdir.toString) - nestedParserSqlContext - .parquetFile(tmpdir.toString) + parquetFile(tmpdir.toString) .toSchemaRDD .registerTempTable("tmpcopy") - val tmpdata = nestedParserSqlContext.sql("SELECT owner, contacts[1].name FROM tmpcopy").collect() + val tmpdata = sql("SELECT owner, contacts[1].name FROM tmpcopy").collect() assert(tmpdata.size === 2) assert(tmpdata(0).size === 2) assert(tmpdata(0)(0) === "Julien Le Dem") @@ -844,20 +825,17 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } test("Writing out Map and reading it back in") { - val data = nestedParserSqlContext - .parquetFile(ParquetTestData.testNestedDir4.toString) - .toSchemaRDD + val data = parquetFile(ParquetTestData.testNestedDir4.toString).toSchemaRDD val tmpdir = Utils.createTempDir() Utils.deleteRecursively(tmpdir) data.saveAsParquetFile(tmpdir.toString) - nestedParserSqlContext - .parquetFile(tmpdir.toString) + parquetFile(tmpdir.toString) .toSchemaRDD .registerTempTable("tmpmapcopy") - val result1 = nestedParserSqlContext.sql("""SELECT data1["key2"] FROM tmpmapcopy""").collect() + val result1 = sql("""SELECT data1["key2"] FROM tmpmapcopy""").collect() assert(result1.size === 1) assert(result1(0)(0) === 2) - val result2 = nestedParserSqlContext.sql("SELECT data2 FROM tmpmapcopy").collect() + val result2 = sql("SELECT data2 FROM tmpmapcopy").collect() assert(result2.size === 1) val entry1 = result2(0)(0) .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] @@ -871,42 +849,10 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(entry2 != null) assert(entry2(0) === 49) assert(entry2(1) === null) - val result3 = nestedParserSqlContext.sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM tmpmapcopy""").collect() + val result3 = sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM tmpmapcopy""").collect() assert(result3.size === 1) assert(result3(0)(0) === 42.toLong) assert(result3(0)(1) === "the answer") Utils.deleteRecursively(tmpdir) } } - -// TODO: the code below is needed temporarily until the standard parser is able to parse -// nested field expressions correctly -class NestedParserSQLContext(@transient override val sparkContext: SparkContext) extends SQLContext(sparkContext) { - override protected[sql] val parser = new NestedSqlParser() -} - -class NestedSqlLexical(override val keywords: Seq[String]) extends SqlLexical(keywords) { - override def identChar = letter | elem('_') - delimiters += (".") -} - -class NestedSqlParser extends SqlParser { - override val lexical = new NestedSqlLexical(reservedWords) - - override protected lazy val baseExpression: PackratParser[Expression] = - expression ~ "[" ~ expression <~ "]" ^^ { - case base ~ _ ~ ordinal => GetItem(base, ordinal) - } | - expression ~ "." ~ ident ^^ { - case base ~ _ ~ fieldName => GetField(base, fieldName) - } | - TRUE ^^^ Literal(true, BooleanType) | - FALSE ^^^ Literal(false, BooleanType) | - cast | - "(" ~> expression <~ ")" | - function | - "-" ~> literal ^^ UnaryMinus | - ident ^^ UnresolvedAttribute | - "*" ^^^ Star(None) | - literal -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 635a9fb0d5..b99caf77bc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.hive.execution -import scala.reflect.ClassTag - -import org.apache.spark.sql.{SQLConf, QueryTest} -import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin} -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.QueryTest import org.apache.spark.sql.hive.test.TestHive._ +case class Nested1(f1: Nested2) +case class Nested2(f2: Nested3) +case class Nested3(f3: Int) + /** * A collection of hive query tests where we generate the answers ourselves instead of depending on * Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is @@ -47,4 +47,11 @@ class SQLQuerySuite extends QueryTest { GROUP BY key, value ORDER BY value) a""").collect().toSeq) } + + test("double nested data") { + sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil).registerTempTable("nested") + checkAnswer( + sql("SELECT f1.f2.f3 FROM nested"), + 1) + } } -- cgit v1.2.3