Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala  34
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala        73
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala         12
3 files changed, 100 insertions(+), 19 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 926a37b363..d2e091f4dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -76,7 +76,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
}
override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) {
- visit(ctx.dataType).asInstanceOf[DataType]
+ visitSparkDataType(ctx.dataType)
}
/* ********************************************************************************************
@@ -1006,7 +1006,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
* Create a [[Cast]] expression.
*/
override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) {
- Cast(expression(ctx.expression), typedVisit(ctx.dataType))
+ Cast(expression(ctx.expression), visitSparkDataType(ctx.dataType))
}
/**
@@ -1425,6 +1425,13 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
* DataType parsing
* ******************************************************************************************** */
/**
+ * Create a Spark DataType.
+ */
+ private def visitSparkDataType(ctx: DataTypeContext): DataType = {
+ HiveStringType.replaceCharType(typedVisit(ctx))
+ }
+
+ /**
* Resolve/create a primitive type.
*/
override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) {
@@ -1438,8 +1445,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
case ("double", Nil) => DoubleType
case ("date", Nil) => DateType
case ("timestamp", Nil) => TimestampType
- case ("char" | "varchar" | "string", Nil) => StringType
- case ("char" | "varchar", _ :: Nil) => StringType
+ case ("string", Nil) => StringType
+ case ("char", length :: Nil) => CharType(length.getText.toInt)
+ case ("varchar", length :: Nil) => VarcharType(length.getText.toInt)
case ("binary", Nil) => BinaryType
case ("decimal", Nil) => DecimalType.USER_DEFAULT
case ("decimal", precision :: Nil) => DecimalType(precision.getText.toInt, 0)
@@ -1461,7 +1469,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
case SqlBaseParser.MAP =>
MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1)))
case SqlBaseParser.STRUCT =>
- createStructType(ctx.complexColTypeList())
+ StructType(Option(ctx.complexColTypeList).toSeq.flatMap(visitComplexColTypeList))
}
}
@@ -1480,7 +1488,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
}
/**
- * Create a [[StructField]] from a column definition.
+ * Create a top-level [[StructField]] from a column definition.
*/
override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) {
import ctx._
@@ -1491,19 +1499,15 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
builder.putString("comment", string(STRING))
}
// Add Hive type string to metadata.
- dataType match {
- case p: PrimitiveDataTypeContext =>
- p.identifier.getText.toLowerCase match {
- case "varchar" | "char" =>
- builder.putString(HIVE_TYPE_STRING, dataType.getText.toLowerCase)
- case _ =>
- }
- case _ =>
+ val rawDataType = typedVisit[DataType](ctx.dataType)
+ val cleanedDataType = HiveStringType.replaceCharType(rawDataType)
+ if (rawDataType != cleanedDataType) {
+ builder.putString(HIVE_TYPE_STRING, rawDataType.catalogString)
}
StructField(
identifier.getText,
- typedVisit(dataType),
+ cleanedDataType,
nullable = true,
builder.build())
}
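
With these parser changes, char(n) and varchar(n) resolve to the new CharType and
VarcharType during parsing, but every entry point touched above (visitSingleDataType,
visitCast, visitColType) immediately rewrites them to StringType via
HiveStringType.replaceCharType; only the column metadata retains the original Hive
type string. A minimal sketch of the resulting behavior, assuming the
CatalystSqlParser entry point (not part of this diff):

    import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
    import org.apache.spark.sql.types._

    // The Hive-only char/varchar types never escape the parser:
    assert(CatalystSqlParser.parseDataType("char(10)") == StringType)
    assert(CatalystSqlParser.parseDataType("array<varchar(5)>") == ArrayType(StringType))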
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala
new file mode 100644
index 0000000000..b319eb70bc
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.types
+
+import scala.math.Ordering
+import scala.reflect.runtime.universe.typeTag
+
+import org.apache.spark.sql.catalyst.ScalaReflectionLock
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * A Hive string type for compatibility. These data types should only be used for parsing,
+ * and should NOT be used anywhere else. Any instance of these data types should be
+ * replaced by a [[StringType]] before analysis.
+ */
+sealed abstract class HiveStringType extends AtomicType {
+ private[sql] type InternalType = UTF8String
+
+ private[sql] val ordering = implicitly[Ordering[InternalType]]
+
+ @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized {
+ typeTag[InternalType]
+ }
+
+ override def defaultSize: Int = length
+
+ private[spark] override def asNullable: HiveStringType = this
+
+ def length: Int
+}
+
+object HiveStringType {
+ def replaceCharType(dt: DataType): DataType = dt match {
+ case ArrayType(et, nullable) =>
+ ArrayType(replaceCharType(et), nullable)
+ case MapType(kt, vt, nullable) =>
+ MapType(replaceCharType(kt), replaceCharType(vt), nullable)
+ case StructType(fields) =>
+ StructType(fields.map { field =>
+ field.copy(dataType = replaceCharType(field.dataType))
+ })
+ case _: HiveStringType => StringType
+ case _ => dt
+ }
+}
+
+/**
+ * Hive char type.
+ */
+case class CharType(length: Int) extends HiveStringType {
+ override def simpleString: String = s"char($length)"
+}
+
+/**
+ * Hive varchar type.
+ */
+case class VarcharType(length: Int) extends HiveStringType {
+ override def simpleString: String = s"varchar($length)"
+}
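
The companion object's replaceCharType is the cleanup hook used by the parser: it
recursively replaces any CharType or VarcharType nested inside arrays, maps, and
structs with StringType. A small usage sketch of the code above:

    import org.apache.spark.sql.types._

    val hiveSchema = StructType(Seq(
      StructField("name", VarcharType(20)),
      StructField("tags", ArrayType(CharType(3)))))

    val cleaned = HiveStringType.replaceCharType(hiveSchema)
    // cleaned.catalogString == "struct<name:string,tags:array<string>>"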
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
index 59ea8916ef..11dda5425c 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
@@ -162,13 +162,16 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll {
|CREATE EXTERNAL TABLE hive_orc(
| a STRING,
| b CHAR(10),
- | c VARCHAR(10))
+ | c VARCHAR(10),
+ | d ARRAY<CHAR(3)>)
|STORED AS orc""".stripMargin)
// Hive throws an exception if I assign the location in the create table statement.
hiveClient.runSqlHive(
s"ALTER TABLE hive_orc SET LOCATION '$uri'")
hiveClient.runSqlHive(
- "INSERT INTO TABLE hive_orc SELECT 'a', 'b', 'c' FROM (SELECT 1) t")
+ """INSERT INTO TABLE hive_orc
+ |SELECT 'a', 'b', 'c', ARRAY(CAST('d' AS CHAR(3)))
+ |FROM (SELECT 1) t""".stripMargin)
// We create a different table in Spark using the same schema which points to
// the same location.
@@ -177,10 +180,11 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll {
|CREATE EXTERNAL TABLE spark_orc(
| a STRING,
| b CHAR(10),
- | c VARCHAR(10))
+ | c VARCHAR(10),
+ | d ARRAY<CHAR(3)>)
|STORED AS orc
|LOCATION '$uri'""".stripMargin)
- val result = Row("a", "b         ", "c")
+ val result = Row("a", "b         ", "c", Seq("d  "))
checkAnswer(spark.table("hive_orc"), result)
checkAnswer(spark.table("spark_orc"), result)
} finally {
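
The padded expected values follow Hive's CHAR semantics: a CHAR(n) value is
space-padded to exactly n characters when read back, so 'b' round-trips as a
10-character string and 'd' as a 3-character one. A sketch of that check,
assuming the spark_orc table created in the test above:

    // Hive pads CHAR(n) with trailing spaces up to length n.
    val row = spark.table("spark_orc").head()
    assert(row.getString(1) == "b" + " " * 9)           // column b: CHAR(10)
    assert(row.getSeq[String](3) == Seq("d" + " " * 2)) // column d: ARRAY<CHAR(3)>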