Diffstat
-rw-r--r--  sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4               |  12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala          |  32
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala |  14
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala                |   4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala           | 152
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala              |   2
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala                |   2
7 files changed, 195 insertions(+), 23 deletions(-)
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index 6a94def65f..a3bbaceca3 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -584,7 +584,7 @@ intervalValue
dataType
: complex=ARRAY '<' dataType '>' #complexDataType
| complex=MAP '<' dataType ',' dataType '>' #complexDataType
- | complex=STRUCT ('<' colTypeList? '>' | NEQ) #complexDataType
+ | complex=STRUCT ('<' complexColTypeList? '>' | NEQ) #complexDataType
| identifier ('(' INTEGER_VALUE (',' INTEGER_VALUE)* ')')? #primitiveDataType
;
@@ -593,7 +593,15 @@ colTypeList
;
colType
- : identifier ':'? dataType (COMMENT STRING)?
+ : identifier dataType (COMMENT STRING)?
+ ;
+
+complexColTypeList
+ : complexColType (',' complexColType)*
+ ;
+
+complexColType
+ : identifier ':' dataType (COMMENT STRING)?
;
whenClause
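
Editor's note: as a quick, hedged illustration of the grammar change above (not part of the patch; a minimal sketch assuming the spark-catalyst test classpath, using CatalystSqlParser, which parses the dataType rule shown here):

import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}

// After this change, fields inside STRUCT<...> must be written "name: type" (the ':' is now
// required there), while plain column lists (colType) no longer accept an optional ':'.
val ok = CatalystSqlParser.parseDataType("struct<x: int, y: string comment 'test'>")

// The previously accepted colon-less form inside STRUCT<...> no longer parses.
try {
  CatalystSqlParser.parseDataType("struct<x int, y string>")
} catch {
  case e: ParseException => println(s"rejected as expected: ${e.getMessage}")
}
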
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index bf3f30279a..929c1c4f2d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -316,7 +316,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
// Create the attributes.
val (attributes, schemaLess) = if (colTypeList != null) {
// Typed return columns.
- (createStructType(colTypeList).toAttributes, false)
+ (createSchema(colTypeList).toAttributes, false)
} else if (identifierSeq != null) {
// Untyped return columns.
val attrs = visitIdentifierSeq(identifierSeq).map { name =>
@@ -1450,14 +1450,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
case SqlBaseParser.MAP =>
MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1)))
case SqlBaseParser.STRUCT =>
- createStructType(ctx.colTypeList())
+ createStructType(ctx.complexColTypeList())
}
}
/**
- * Create a [[StructType]] from a sequence of [[StructField]]s.
+ * Create top level table schema.
*/
- protected def createStructType(ctx: ColTypeListContext): StructType = {
+ protected def createSchema(ctx: ColTypeListContext): StructType = {
StructType(Option(ctx).toSeq.flatMap(visitColTypeList))
}
@@ -1476,4 +1476,28 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
val structField = StructField(identifier.getText, typedVisit(dataType), nullable = true)
if (STRING == null) structField else structField.withComment(string(STRING))
}
+
+ /**
+ * Create a [[StructType]] from a sequence of [[StructField]]s.
+ */
+ protected def createStructType(ctx: ComplexColTypeListContext): StructType = {
+ StructType(Option(ctx).toSeq.flatMap(visitComplexColTypeList))
+ }
+
+ /**
+ * Create a [[StructType]] from a number of column definitions.
+ */
+ override def visitComplexColTypeList(
+ ctx: ComplexColTypeListContext): Seq[StructField] = withOrigin(ctx) {
+ ctx.complexColType().asScala.map(visitComplexColType)
+ }
+
+ /**
+ * Create a [[StructField]] from a column definition.
+ */
+ override def visitComplexColType(ctx: ComplexColTypeContext): StructField = withOrigin(ctx) {
+ import ctx._
+ val structField = StructField(identifier.getText, typedVisit(dataType), nullable = true)
+ if (STRING == null) structField else structField.withComment(string(STRING))
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
index 020fb16f6f..3964fa3924 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
@@ -116,6 +116,7 @@ class DataTypeParserSuite extends SparkFunSuite {
unsupported("it is not a data type")
unsupported("struct<x+y: int, 1.1:timestamp>")
unsupported("struct<x: int")
+ unsupported("struct<x int, y string>")
// DataType parser accepts certain reserved keywords.
checkDataType(
@@ -125,16 +126,11 @@ class DataTypeParserSuite extends SparkFunSuite {
StructField("DATE", BooleanType, true) :: Nil)
)
- // Define struct columns without ':'
- checkDataType(
- "struct<x int, y string>",
- (new StructType).add("x", IntegerType).add("y", StringType))
-
- checkDataType(
- "struct<`x``y` int>",
- (new StructType).add("x`y", IntegerType))
-
// Use SQL keywords.
checkDataType("struct<end: long, select: int, from: string>",
(new StructType).add("end", LongType).add("select", IntegerType).add("from", StringType))
+
+ // DataType parser accepts comments.
+ checkDataType("Struct<x: INT, y: STRING COMMENT 'test'>",
+ (new StructType).add("x", IntegerType).add("y", StringType, true, "test"))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index 085bb9fc3c..5f87b71210 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -340,7 +340,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder {
if (provider.toLowerCase == "hive") {
throw new AnalysisException("Cannot create hive serde table with CREATE TABLE USING")
}
- val schema = Option(ctx.colTypeList()).map(createStructType)
+ val schema = Option(ctx.colTypeList()).map(createSchema)
val partitionColumnNames =
Option(ctx.partitionColumnNames)
.map(visitIdentifierList(_).toArray)
@@ -399,7 +399,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder {
ctx: CreateTempViewUsingContext): LogicalPlan = withOrigin(ctx) {
CreateTempViewUsing(
tableIdent = visitTableIdentifier(ctx.tableIdentifier()),
- userSpecifiedSchema = Option(ctx.colTypeList()).map(createStructType),
+ userSpecifiedSchema = Option(ctx.colTypeList()).map(createSchema),
replace = ctx.REPLACE != null,
provider = ctx.tableProvider.qualifiedName.getText,
options = Option(ctx.tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
index 6712d32924..e0976ae950 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
@@ -17,13 +17,17 @@
package org.apache.spark.sql.execution
+import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType}
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.command.{DescribeFunctionCommand, DescribeTableCommand,
ShowFunctionsCommand}
-import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.execution.datasources.{CreateTable, CreateTempViewUsing}
+import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
+import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
/**
* Parser test cases for rules defined in [[SparkSqlParser]].
@@ -35,8 +39,23 @@ class SparkSqlParserSuite extends PlanTest {
private lazy val parser = new SparkSqlParser(new SQLConf)
+ /**
+ * Normalizes plans:
+ * - CreateTable: the createTime in tableDesc will be replaced by -1L.
+ */
+ private def normalizePlan(plan: LogicalPlan): LogicalPlan = {
+ plan match {
+ case CreateTable(tableDesc, mode, query) =>
+ val newTableDesc = tableDesc.copy(createTime = -1L)
+ CreateTable(newTableDesc, mode, query)
+ case _ => plan // Don't transform
+ }
+ }
+
private def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = {
- comparePlans(parser.parsePlan(sqlCommand), plan)
+ val normalized1 = normalizePlan(parser.parsePlan(sqlCommand))
+ val normalized2 = normalizePlan(plan)
+ comparePlans(normalized1, normalized2)
}
private def intercept(sqlCommand: String, messages: String*): Unit = {
@@ -68,9 +87,134 @@ class SparkSqlParserSuite extends PlanTest {
DescribeFunctionCommand(FunctionIdentifier("bar", database = None), isExtended = true))
assertEqual("describe function foo.bar",
DescribeFunctionCommand(
- FunctionIdentifier("bar", database = Option("foo")), isExtended = false))
+ FunctionIdentifier("bar", database = Some("foo")), isExtended = false))
assertEqual("describe function extended f.bar",
- DescribeFunctionCommand(FunctionIdentifier("bar", database = Option("f")), isExtended = true))
+ DescribeFunctionCommand(FunctionIdentifier("bar", database = Some("f")), isExtended = true))
+ }
+
+ private def createTableUsing(
+ table: String,
+ database: Option[String] = None,
+ tableType: CatalogTableType = CatalogTableType.MANAGED,
+ storage: CatalogStorageFormat = CatalogStorageFormat.empty,
+ schema: StructType = new StructType,
+ provider: Option[String] = Some("parquet"),
+ partitionColumnNames: Seq[String] = Seq.empty,
+ bucketSpec: Option[BucketSpec] = None,
+ mode: SaveMode = SaveMode.ErrorIfExists,
+ query: Option[LogicalPlan] = None): CreateTable = {
+ CreateTable(
+ CatalogTable(
+ identifier = TableIdentifier(table, database),
+ tableType = tableType,
+ storage = storage,
+ schema = schema,
+ provider = provider,
+ partitionColumnNames = partitionColumnNames,
+ bucketSpec = bucketSpec
+ ), mode, query
+ )
+ }
+
+ private def createTempViewUsing(
+ table: String,
+ database: Option[String] = None,
+ schema: Option[StructType] = None,
+ replace: Boolean = true,
+ provider: String = "parquet",
+ options: Map[String, String] = Map.empty): LogicalPlan = {
+ CreateTempViewUsing(TableIdentifier(table, database), schema, replace, provider, options)
+ }
+
+ private def createTable(
+ table: String,
+ database: Option[String] = None,
+ tableType: CatalogTableType = CatalogTableType.MANAGED,
+ storage: CatalogStorageFormat = CatalogStorageFormat.empty.copy(
+ inputFormat = HiveSerDe.sourceToSerDe("textfile").get.inputFormat,
+ outputFormat = HiveSerDe.sourceToSerDe("textfile").get.outputFormat),
+ schema: StructType = new StructType,
+ provider: Option[String] = Some("hive"),
+ partitionColumnNames: Seq[String] = Seq.empty,
+ comment: Option[String] = None,
+ mode: SaveMode = SaveMode.ErrorIfExists,
+ query: Option[LogicalPlan] = None): CreateTable = {
+ CreateTable(
+ CatalogTable(
+ identifier = TableIdentifier(table, database),
+ tableType = tableType,
+ storage = storage,
+ schema = schema,
+ provider = provider,
+ partitionColumnNames = partitionColumnNames,
+ comment = comment
+ ), mode, query
+ )
+ }
+
+ test("create table - schema") {
+ assertEqual("CREATE TABLE my_tab(a INT COMMENT 'test', b STRING)",
+ createTable(
+ table = "my_tab",
+ schema = (new StructType)
+ .add("a", IntegerType, nullable = true, "test")
+ .add("b", StringType)
+ )
+ )
+ assertEqual("CREATE TABLE my_tab(a INT COMMENT 'test', b STRING) " +
+ "PARTITIONED BY (c INT, d STRING COMMENT 'test2')",
+ createTable(
+ table = "my_tab",
+ schema = (new StructType)
+ .add("a", IntegerType, nullable = true, "test")
+ .add("b", StringType)
+ .add("c", IntegerType)
+ .add("d", StringType, nullable = true, "test2"),
+ partitionColumnNames = Seq("c", "d")
+ )
+ )
+ assertEqual("CREATE TABLE my_tab(id BIGINT, nested STRUCT<col1: STRING,col2: INT>)",
+ createTable(
+ table = "my_tab",
+ schema = (new StructType)
+ .add("id", LongType)
+ .add("nested", (new StructType)
+ .add("col1", StringType)
+ .add("col2", IntegerType)
+ )
+ )
+ )
+ // Partitioning by a StructType column is accepted by `SparkSqlParser` but is rejected by the
+ // analysis rule in `AnalyzeCreateTable`.
+ assertEqual("CREATE TABLE my_tab(a INT COMMENT 'test', b STRING) " +
+ "PARTITIONED BY (nested STRUCT<col1: STRING,col2: INT>)",
+ createTable(
+ table = "my_tab",
+ schema = (new StructType)
+ .add("a", IntegerType, nullable = true, "test")
+ .add("b", StringType)
+ .add("nested", (new StructType)
+ .add("col1", StringType)
+ .add("col2", IntegerType)
+ ),
+ partitionColumnNames = Seq("nested")
+ )
+ )
+ intercept("CREATE TABLE my_tab(a: INT COMMENT 'test', b: STRING)",
+ "no viable alternative at input")
+ }
+
+ test("create table using - schema") {
+ assertEqual("CREATE TABLE my_tab(a INT COMMENT 'test', b STRING) USING parquet",
+ createTableUsing(
+ table = "my_tab",
+ schema = (new StructType)
+ .add("a", IntegerType, nullable = true, "test")
+ .add("b", StringType)
+ )
+ )
+ intercept("CREATE TABLE my_tab(a: INT COMMENT 'test', b: STRING) USING parquet",
+ "no viable alternative at input")
}
test("SPARK-17328 Fix NPE with EXPLAIN DESCRIBE TABLE") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
index b5499f2884..1bcb810a15 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -642,7 +642,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
val csvFile =
Thread.currentThread().getContextClassLoader.getResource("test-data/cars.csv").toString
withView("testview") {
- sql(s"CREATE OR REPLACE TEMPORARY VIEW testview (c1: String, c2: String) USING " +
+ sql(s"CREATE OR REPLACE TEMPORARY VIEW testview (c1 String, c2 String) USING " +
"org.apache.spark.sql.execution.datasources.csv.CSVFileFormat " +
s"OPTIONS (PATH '$csvFile')")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala
index 54e27b6f73..9ce3338647 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala
@@ -243,7 +243,7 @@ class HiveDDLCommandSuite extends PlanTest {
.asInstanceOf[ScriptTransformation].copy(ioschema = null)
val plan2 = parser.parsePlan("map a, b using 'func' as c, d from e")
.asInstanceOf[ScriptTransformation].copy(ioschema = null)
- val plan3 = parser.parsePlan("reduce a, b using 'func' as (c: int, d decimal(10, 0)) from e")
+ val plan3 = parser.parsePlan("reduce a, b using 'func' as (c int, d decimal(10, 0)) from e")
.asInstanceOf[ScriptTransformation].copy(ioschema = null)
val p = ScriptTransformation(