diff options
Diffstat (limited to 'sql/catalyst')
3 files changed, 104 insertions, 1 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index d687a85c18..f704b0998c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.Origin -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, StructType} /** * Base SQL parsing infrastructure. @@ -49,6 +49,14 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier()) } + /** + * Creates StructType for a given SQL string, which is a comma separated list of field + * definitions which will preserve the correct Hive metadata. + */ + override def parseTableSchema(sqlText: String): StructType = parse(sqlText) { parser => + StructType(astBuilder.visitColTypeList(parser.colTypeList())) + } + /** Creates LogicalPlan for a given SQL string. */ override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser => astBuilder.visitSingleStatement(parser.singleStatement()) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala index 7f35d650b9..6edbe25397 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.types.StructType /** * Interface for a parser. @@ -33,4 +34,10 @@ trait ParserInterface { /** Creates TableIdentifier for a given SQL string. */ def parseTableIdentifier(sqlText: String): TableIdentifier + + /** + * Creates StructType for a given SQL string, which is a comma separated list of field + * definitions which will preserve the correct Hive metadata. + */ + def parseTableSchema(sqlText: String): StructType } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala new file mode 100644 index 0000000000..da1041d617 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala @@ -0,0 +1,88 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ + +class TableSchemaParserSuite extends SparkFunSuite { + + def parse(sql: String): StructType = CatalystSqlParser.parseTableSchema(sql) + + def checkTableSchema(tableSchemaString: String, expectedDataType: DataType): Unit = { + test(s"parse $tableSchemaString") { + assert(parse(tableSchemaString) === expectedDataType) + } + } + + def assertError(sql: String): Unit = + intercept[ParseException](CatalystSqlParser.parseTableSchema(sql)) + + checkTableSchema("a int", new StructType().add("a", "int")) + checkTableSchema("A int", new StructType().add("A", "int")) + checkTableSchema("a INT", new StructType().add("a", "int")) + checkTableSchema("`!@#$%.^&*()` string", new StructType().add("!@#$%.^&*()", "string")) + checkTableSchema("a int, b long", new StructType().add("a", "int").add("b", "long")) + checkTableSchema("a STRUCT<intType: int, ts:timestamp>", + StructType( + StructField("a", StructType( + StructField("intType", IntegerType) :: + StructField("ts", TimestampType) :: Nil)) :: Nil)) + checkTableSchema( + "a int comment 'test'", + new StructType().add("a", "int", nullable = true, "test")) + + test("complex hive type") { + val tableSchemaString = + """ + |complexStructCol struct< + |struct:struct<deciMal:DECimal, anotherDecimal:decimAL(5,2)>, + |MAP:Map<timestamp, varchar(10)>, + |arrAy:Array<double>, + |anotherArray:Array<char(9)>> + """.stripMargin.replace("\n", "") + + val builder = new MetadataBuilder + builder.putString(HIVE_TYPE_STRING, + "struct<struct:struct<deciMal:decimal(10,0),anotherDecimal:decimal(5,2)>," + + "MAP:map<timestamp,varchar(10)>,arrAy:array<double>,anotherArray:array<char(9)>>") + + val expectedDataType = + StructType( + StructField("complexStructCol", StructType( + StructField("struct", + StructType( + StructField("deciMal", DecimalType.USER_DEFAULT) :: + StructField("anotherDecimal", DecimalType(5, 2)) :: Nil)) :: + StructField("MAP", MapType(TimestampType, StringType)) :: + StructField("arrAy", ArrayType(DoubleType)) :: + StructField("anotherArray", ArrayType(StringType)) :: Nil), + nullable = true, + builder.build()) :: Nil) + + assert(parse(tableSchemaString) === expectedDataType) + } + + // Negative cases + assertError("") + assertError("a") + assertError("a INT b long") + assertError("a INT,, b long") + assertError("a INT, b long,,") + assertError("a INT, b long, c int,") +} |