aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorHerman van Hovell <hvanhovell@databricks.com>2017-02-23 10:25:18 -0800
committerWenchen Fan <wenchen@databricks.com>2017-02-23 10:25:18 -0800
commit78eae7e67fd5dec0c2d5b18000053ce86cd0f1ae (patch)
treece66255c5a02be3f56c4ced787d960198e0b1f0f /sql/catalyst
parent93aa4271596a30752dc5234d869c3ae2f6e8e723 (diff)
downloadspark-78eae7e67fd5dec0c2d5b18000053ce86cd0f1ae.tar.gz
spark-78eae7e67fd5dec0c2d5b18000053ce86cd0f1ae.tar.bz2
spark-78eae7e67fd5dec0c2d5b18000053ce86cd0f1ae.zip
[SPARK-19459] Support for nested char/varchar fields in ORC
## What changes were proposed in this pull request? This PR is a small follow-up on https://github.com/apache/spark/pull/16804. This PR also adds support for nested char/varchar fields in orc. ## How was this patch tested? I have added a regression test to the OrcSourceSuite. Author: Herman van Hovell <hvanhovell@databricks.com> Closes #17030 from hvanhovell/SPARK-19459-follow-up.
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala34
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala73
2 files changed, 92 insertions, 15 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 926a37b363..d2e091f4dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -76,7 +76,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
}
override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) {
- visit(ctx.dataType).asInstanceOf[DataType]
+ visitSparkDataType(ctx.dataType)
}
/* ********************************************************************************************
@@ -1006,7 +1006,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
* Create a [[Cast]] expression.
*/
override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) {
- Cast(expression(ctx.expression), typedVisit(ctx.dataType))
+ Cast(expression(ctx.expression), visitSparkDataType(ctx.dataType))
}
/**
@@ -1425,6 +1425,13 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
* DataType parsing
* ******************************************************************************************** */
/**
+ * Create a Spark DataType.
+ */
+ private def visitSparkDataType(ctx: DataTypeContext): DataType = {
+ HiveStringType.replaceCharType(typedVisit(ctx))
+ }
+
+ /**
* Resolve/create a primitive type.
*/
override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) {
@@ -1438,8 +1445,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
case ("double", Nil) => DoubleType
case ("date", Nil) => DateType
case ("timestamp", Nil) => TimestampType
- case ("char" | "varchar" | "string", Nil) => StringType
- case ("char" | "varchar", _ :: Nil) => StringType
+ case ("string", Nil) => StringType
+ case ("char", length :: Nil) => CharType(length.getText.toInt)
+ case ("varchar", length :: Nil) => VarcharType(length.getText.toInt)
case ("binary", Nil) => BinaryType
case ("decimal", Nil) => DecimalType.USER_DEFAULT
case ("decimal", precision :: Nil) => DecimalType(precision.getText.toInt, 0)
@@ -1461,7 +1469,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
case SqlBaseParser.MAP =>
MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1)))
case SqlBaseParser.STRUCT =>
- createStructType(ctx.complexColTypeList())
+ StructType(Option(ctx.complexColTypeList).toSeq.flatMap(visitComplexColTypeList))
}
}
@@ -1480,7 +1488,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
}
/**
- * Create a [[StructField]] from a column definition.
+ * Create a top level [[StructField]] from a column definition.
*/
override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) {
import ctx._
@@ -1491,19 +1499,15 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
builder.putString("comment", string(STRING))
}
// Add Hive type string to metadata.
- dataType match {
- case p: PrimitiveDataTypeContext =>
- p.identifier.getText.toLowerCase match {
- case "varchar" | "char" =>
- builder.putString(HIVE_TYPE_STRING, dataType.getText.toLowerCase)
- case _ =>
- }
- case _ =>
+ val rawDataType = typedVisit[DataType](ctx.dataType)
+ val cleanedDataType = HiveStringType.replaceCharType(rawDataType)
+ if (rawDataType != cleanedDataType) {
+ builder.putString(HIVE_TYPE_STRING, rawDataType.catalogString)
}
StructField(
identifier.getText,
- typedVisit(dataType),
+ cleanedDataType,
nullable = true,
builder.build())
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala
new file mode 100644
index 0000000000..b319eb70bc
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.types
+
+import scala.math.Ordering
+import scala.reflect.runtime.universe.typeTag
+
+import org.apache.spark.sql.catalyst.ScalaReflectionLock
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * A hive string type for compatibility. These datatypes should only used for parsing,
+ * and should NOT be used anywhere else. Any instance of these data types should be
+ * replaced by a [[StringType]] before analysis.
+ */
+sealed abstract class HiveStringType extends AtomicType {
+ private[sql] type InternalType = UTF8String
+
+ private[sql] val ordering = implicitly[Ordering[InternalType]]
+
+ @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized {
+ typeTag[InternalType]
+ }
+
+ override def defaultSize: Int = length
+
+ private[spark] override def asNullable: HiveStringType = this
+
+ def length: Int
+}
+
+object HiveStringType {
+ def replaceCharType(dt: DataType): DataType = dt match {
+ case ArrayType(et, nullable) =>
+ ArrayType(replaceCharType(et), nullable)
+ case MapType(kt, vt, nullable) =>
+ MapType(replaceCharType(kt), replaceCharType(vt), nullable)
+ case StructType(fields) =>
+ StructType(fields.map { field =>
+ field.copy(dataType = replaceCharType(field.dataType))
+ })
+ case _: HiveStringType => StringType
+ case _ => dt
+ }
+}
+
+/**
+ * Hive char type.
+ */
+case class CharType(length: Int) extends HiveStringType {
+ override def simpleString: String = s"char($length)"
+}
+
+/**
+ * Hive varchar type.
+ */
+case class VarcharType(length: Int) extends HiveStringType {
+ override def simpleString: String = s"varchar($length)"
+}