aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorscwf <wangfei1@huawei.com>2015-01-10 13:53:21 -0800
committerMichael Armbrust <michael@databricks.com>2015-01-10 13:53:21 -0800
commit693a323a70aba91e6c100dd5561d218a75b7895e (patch)
tree3604ec22163d5296496d1d1907e6b3edbafc0108 /sql
parent4b39fd1e63188821fc84a13f7ccb6e94277f4be7 (diff)
downloadspark-693a323a70aba91e6c100dd5561d218a75b7895e.tar.gz
spark-693a323a70aba91e6c100dd5561d218a75b7895e.tar.bz2
spark-693a323a70aba91e6c100dd5561d218a75b7895e.zip
[SPARK-4574][SQL] Adding support for defining schema in foreign DDL commands.
Adding support for defining schema in foreign DDL commands. Now foreign DDL support commands like: ``` CREATE TEMPORARY TABLE avroTable USING org.apache.spark.sql.avro OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro") ``` With this PR user can define schema instead of infer from file, so support ddl command as follows: ``` CREATE TEMPORARY TABLE avroTable(a int, b string) USING org.apache.spark.sql.avro OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro") ``` Author: scwf <wangfei1@huawei.com> Author: Yin Huai <yhuai@databricks.com> Author: Fei Wang <wangfei1@huawei.com> Author: wangfei <wangfei1@huawei.com> Closes #3431 from scwf/ddl and squashes the following commits: 7e79ce5 [Fei Wang] Merge pull request #22 from yhuai/pr3431yin 38f634e [Yin Huai] Remove Option from createRelation. 65e9c73 [Yin Huai] Revert all changes since applying a given schema has not been testd. a852b10 [scwf] remove cleanIdentifier f336a16 [Fei Wang] Merge pull request #21 from yhuai/pr3431yin baf79b5 [Yin Huai] Test special characters quoted by backticks. 50a03b0 [Yin Huai] Use JsonRDD.nullTypeToStringType to convert NullType to StringType. 1eeb769 [Fei Wang] Merge pull request #20 from yhuai/pr3431yin f5c22b0 [Yin Huai] Refactor code and update test cases. f1cffe4 [Yin Huai] Revert "minor refactory" b621c8f [scwf] minor refactory d02547f [scwf] fix HiveCompatibilitySuite test failure 8dfbf7a [scwf] more tests for complex data type ddab984 [Fei Wang] Merge pull request #19 from yhuai/pr3431yin 91ad91b [Yin Huai] Parse data types in DDLParser. cf982d2 [scwf] fixed test failure 445b57b [scwf] address comments 02a662c [scwf] style issue 44eb70c [scwf] fix decimal parser issue 83b6fc3 [scwf] minor fix 9bf12f8 [wangfei] adding test case 7787ec7 [wangfei] added SchemaRelationProvider 0ba70df [wangfei] draft version
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala35
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala138
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala29
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala192
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala114
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala5
6 files changed, 400 insertions, 113 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
index fc70c18343..a9a6696cb1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
@@ -18,31 +18,48 @@
package org.apache.spark.sql.json
import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.types.StructType
import org.apache.spark.sql.sources._
-private[sql] class DefaultSource extends RelationProvider {
- /** Returns a new base relation with the given parameters. */
+private[sql] class DefaultSource extends RelationProvider with SchemaRelationProvider {
+
+ /** Returns a new base relation with the parameters. */
override def createRelation(
sqlContext: SQLContext,
parameters: Map[String, String]): BaseRelation = {
val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
- JSONRelation(fileName, samplingRatio)(sqlContext)
+ JSONRelation(fileName, samplingRatio, None)(sqlContext)
+ }
+
+ /** Returns a new base relation with the given schema and parameters. */
+ override def createRelation(
+ sqlContext: SQLContext,
+ parameters: Map[String, String],
+ schema: StructType): BaseRelation = {
+ val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
+ val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
+
+ JSONRelation(fileName, samplingRatio, Some(schema))(sqlContext)
}
}
-private[sql] case class JSONRelation(fileName: String, samplingRatio: Double)(
+private[sql] case class JSONRelation(
+ fileName: String,
+ samplingRatio: Double,
+ userSpecifiedSchema: Option[StructType])(
@transient val sqlContext: SQLContext)
extends TableScan {
private def baseRDD = sqlContext.sparkContext.textFile(fileName)
- override val schema =
- JsonRDD.inferSchema(
- baseRDD,
- samplingRatio,
- sqlContext.columnNameOfCorruptRecord)
+ override val schema = userSpecifiedSchema.getOrElse(
+ JsonRDD.nullTypeToStringType(
+ JsonRDD.inferSchema(
+ baseRDD,
+ samplingRatio,
+ sqlContext.columnNameOfCorruptRecord)))
override def buildScan() =
JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.columnNameOfCorruptRecord)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index 364bacec83..fe2c4d8436 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -17,16 +17,15 @@
package org.apache.spark.sql.sources
-import org.apache.spark.Logging
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.execution.RunnableCommand
-import org.apache.spark.util.Utils
-
import scala.language.implicitConversions
-import scala.util.parsing.combinator.lexical.StdLexical
import scala.util.parsing.combinator.syntactical.StandardTokenParsers
import scala.util.parsing.combinator.PackratParsers
+import org.apache.spark.Logging
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.execution.RunnableCommand
+import org.apache.spark.util.Utils
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.SqlLexical
@@ -44,6 +43,14 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
}
}
+ def parseType(input: String): DataType = {
+ phrase(dataType)(new lexical.Scanner(input)) match {
+ case Success(r, x) => r
+ case x =>
+ sys.error(s"Unsupported dataType: $x")
+ }
+ }
+
protected case class Keyword(str: String)
protected implicit def asParser(k: Keyword): Parser[String] =
@@ -55,6 +62,24 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
protected val USING = Keyword("USING")
protected val OPTIONS = Keyword("OPTIONS")
+ // Data types.
+ protected val STRING = Keyword("STRING")
+ protected val BINARY = Keyword("BINARY")
+ protected val BOOLEAN = Keyword("BOOLEAN")
+ protected val TINYINT = Keyword("TINYINT")
+ protected val SMALLINT = Keyword("SMALLINT")
+ protected val INT = Keyword("INT")
+ protected val BIGINT = Keyword("BIGINT")
+ protected val FLOAT = Keyword("FLOAT")
+ protected val DOUBLE = Keyword("DOUBLE")
+ protected val DECIMAL = Keyword("DECIMAL")
+ protected val DATE = Keyword("DATE")
+ protected val TIMESTAMP = Keyword("TIMESTAMP")
+ protected val VARCHAR = Keyword("VARCHAR")
+ protected val ARRAY = Keyword("ARRAY")
+ protected val MAP = Keyword("MAP")
+ protected val STRUCT = Keyword("STRUCT")
+
// Use reflection to find the reserved words defined in this class.
protected val reservedWords =
this.getClass
@@ -67,15 +92,25 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
protected lazy val ddl: Parser[LogicalPlan] = createTable
/**
- * CREATE TEMPORARY TABLE avroTable
+ * `CREATE TEMPORARY TABLE avroTable
* USING org.apache.spark.sql.avro
- * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")
+ * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")`
+ * or
+ * `CREATE TEMPORARY TABLE avroTable(intField int, stringField string...)
+ * USING org.apache.spark.sql.avro
+ * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")`
*/
protected lazy val createTable: Parser[LogicalPlan] =
- CREATE ~ TEMPORARY ~ TABLE ~> ident ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ {
- case tableName ~ provider ~ opts =>
- CreateTableUsing(tableName, provider, opts)
+ (
+ CREATE ~ TEMPORARY ~ TABLE ~> ident
+ ~ (tableCols).? ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ {
+ case tableName ~ columns ~ provider ~ opts =>
+ val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields)))
+ CreateTableUsing(tableName, userSpecifiedSchema, provider, opts)
}
+ )
+
+ protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")"
protected lazy val options: Parser[Map[String, String]] =
"(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap }
@@ -83,10 +118,66 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")}
protected lazy val pair: Parser[(String, String)] = ident ~ stringLit ^^ { case k ~ v => (k,v) }
+
+ protected lazy val column: Parser[StructField] =
+ ident ~ dataType ^^ { case columnName ~ typ =>
+ StructField(columnName, typ)
+ }
+
+ protected lazy val primitiveType: Parser[DataType] =
+ STRING ^^^ StringType |
+ BINARY ^^^ BinaryType |
+ BOOLEAN ^^^ BooleanType |
+ TINYINT ^^^ ByteType |
+ SMALLINT ^^^ ShortType |
+ INT ^^^ IntegerType |
+ BIGINT ^^^ LongType |
+ FLOAT ^^^ FloatType |
+ DOUBLE ^^^ DoubleType |
+ fixedDecimalType | // decimal with precision/scale
+ DECIMAL ^^^ DecimalType.Unlimited | // decimal with no precision/scale
+ DATE ^^^ DateType |
+ TIMESTAMP ^^^ TimestampType |
+ VARCHAR ~ "(" ~ numericLit ~ ")" ^^^ StringType
+
+ protected lazy val fixedDecimalType: Parser[DataType] =
+ (DECIMAL ~ "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ {
+ case precision ~ scale => DecimalType(precision.toInt, scale.toInt)
+ }
+
+ protected lazy val arrayType: Parser[DataType] =
+ ARRAY ~> "<" ~> dataType <~ ">" ^^ {
+ case tpe => ArrayType(tpe)
+ }
+
+ protected lazy val mapType: Parser[DataType] =
+ MAP ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ {
+ case t1 ~ _ ~ t2 => MapType(t1, t2)
+ }
+
+ protected lazy val structField: Parser[StructField] =
+ ident ~ ":" ~ dataType ^^ {
+ case fieldName ~ _ ~ tpe => StructField(fieldName, tpe, nullable = true)
+ }
+
+ protected lazy val structType: Parser[DataType] =
+ (STRUCT ~> "<" ~> repsep(structField, ",") <~ ">" ^^ {
+ case fields => new StructType(fields)
+ }) |
+ (STRUCT ~> "<>" ^^ {
+ case fields => new StructType(Nil)
+ })
+
+ private[sql] lazy val dataType: Parser[DataType] =
+ arrayType |
+ mapType |
+ structType |
+ primitiveType
}
private[sql] case class CreateTableUsing(
tableName: String,
+ userSpecifiedSchema: Option[StructType],
provider: String,
options: Map[String, String]) extends RunnableCommand {
@@ -99,8 +190,29 @@ private[sql] case class CreateTableUsing(
sys.error(s"Failed to load class for data source: $provider")
}
}
- val dataSource = clazz.newInstance().asInstanceOf[org.apache.spark.sql.sources.RelationProvider]
- val relation = dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options))
+
+ val relation = userSpecifiedSchema match {
+ case Some(schema: StructType) => {
+ clazz.newInstance match {
+ case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
+ dataSource
+ .asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider]
+ .createRelation(sqlContext, new CaseInsensitiveMap(options), schema)
+ case _ =>
+ sys.error(s"${clazz.getCanonicalName} should extend SchemaRelationProvider.")
+ }
+ }
+ case None => {
+ clazz.newInstance match {
+ case dataSource: org.apache.spark.sql.sources.RelationProvider =>
+ dataSource
+ .asInstanceOf[org.apache.spark.sql.sources.RelationProvider]
+ .createRelation(sqlContext, new CaseInsensitiveMap(options))
+ case _ =>
+ sys.error(s"${clazz.getCanonicalName} should extend RelationProvider.")
+ }
+ }
+ }
sqlContext.baseRelationToSchemaRDD(relation).registerTempTable(tableName)
Seq.empty
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 02eff80456..990f7e0e74 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.sources
import org.apache.spark.annotation.{Experimental, DeveloperApi}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{SQLConf, Row, SQLContext, StructType}
+import org.apache.spark.sql.{Row, SQLContext, StructType}
import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute}
/**
@@ -46,6 +46,33 @@ trait RelationProvider {
/**
* ::DeveloperApi::
+ * Implemented by objects that produce relations for a specific kind of data source. When
+ * Spark SQL is given a DDL operation with
+ * 1. USING clause: to specify the implemented SchemaRelationProvider
+ * 2. User defined schema: users can define schema optionally when create table
+ *
+ * Users may specify the fully qualified class name of a given data source. When that class is
+ * not found Spark SQL will append the class name `DefaultSource` to the path, allowing for
+ * less verbose invocation. For example, 'org.apache.spark.sql.json' would resolve to the
+ * data source 'org.apache.spark.sql.json.DefaultSource'
+ *
+ * A new instance of this class with be instantiated each time a DDL call is made.
+ */
+@DeveloperApi
+trait SchemaRelationProvider {
+ /**
+ * Returns a new base relation with the given parameters and user defined schema.
+ * Note: the parameters' keywords are case insensitive and this insensitivity is enforced
+ * by the Map that is passed to the function.
+ */
+ def createRelation(
+ sqlContext: SQLContext,
+ parameters: Map[String, String],
+ schema: StructType): BaseRelation
+}
+
+/**
+ * ::DeveloperApi::
* Represents a collection of tuples with a known schema. Classes that extend BaseRelation must
* be able to produce the schema of their data in the form of a [[StructType]] Concrete
* implementation should inherit from one of the descendant `Scan` classes, which define various
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 3cd7b0115d..605190f5ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -17,7 +17,10 @@
package org.apache.spark.sql.sources
+import java.sql.{Timestamp, Date}
+
import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.types.DecimalType
class DefaultSource extends SimpleScanSource
@@ -38,9 +41,77 @@ case class SimpleScan(from: Int, to: Int)(@transient val sqlContext: SQLContext)
override def buildScan() = sqlContext.sparkContext.parallelize(from to to).map(Row(_))
}
+class AllDataTypesScanSource extends SchemaRelationProvider {
+ override def createRelation(
+ sqlContext: SQLContext,
+ parameters: Map[String, String],
+ schema: StructType): BaseRelation = {
+ AllDataTypesScan(parameters("from").toInt, parameters("TO").toInt, schema)(sqlContext)
+ }
+}
+
+case class AllDataTypesScan(
+ from: Int,
+ to: Int,
+ userSpecifiedSchema: StructType)(@transient val sqlContext: SQLContext)
+ extends TableScan {
+
+ override def schema = userSpecifiedSchema
+
+ override def buildScan() = {
+ sqlContext.sparkContext.parallelize(from to to).map { i =>
+ Row(
+ s"str_$i",
+ s"str_$i".getBytes(),
+ i % 2 == 0,
+ i.toByte,
+ i.toShort,
+ i,
+ i.toLong,
+ i.toFloat,
+ i.toDouble,
+ BigDecimal(i),
+ BigDecimal(i),
+ new Date((i + 1) * 8640000),
+ new Timestamp(20000 + i),
+ s"varchar_$i",
+ Seq(i, i + 1),
+ Seq(Map(s"str_$i" -> Row(i.toLong))),
+ Map(i -> i.toString),
+ Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)),
+ Row(i, i.toString),
+ Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000)))))
+ }
+ }
+}
+
class TableScanSuite extends DataSourceTest {
import caseInsensisitiveContext._
+ var tableWithSchemaExpected = (1 to 10).map { i =>
+ Row(
+ s"str_$i",
+ s"str_$i",
+ i % 2 == 0,
+ i.toByte,
+ i.toShort,
+ i,
+ i.toLong,
+ i.toFloat,
+ i.toDouble,
+ BigDecimal(i),
+ BigDecimal(i),
+ new Date((i + 1) * 8640000),
+ new Timestamp(20000 + i),
+ s"varchar_$i",
+ Seq(i, i + 1),
+ Seq(Map(s"str_$i" -> Row(i.toLong))),
+ Map(i -> i.toString),
+ Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)),
+ Row(i, i.toString),
+ Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000)))))
+ }.toSeq
+
before {
sql(
"""
@@ -51,6 +122,37 @@ class TableScanSuite extends DataSourceTest {
| To '10'
|)
""".stripMargin)
+
+ sql(
+ """
+ |CREATE TEMPORARY TABLE tableWithSchema (
+ |`string$%Field` stRIng,
+ |binaryField binary,
+ |`booleanField` boolean,
+ |ByteField tinyint,
+ |shortField smaLlint,
+ |int_Field iNt,
+ |`longField_:,<>=+/~^` Bigint,
+ |floatField flOat,
+ |doubleField doubLE,
+ |decimalField1 decimal,
+ |decimalField2 decimal(9,2),
+ |dateField dAte,
+ |timestampField tiMestamp,
+ |varcharField varchaR(12),
+ |arrayFieldSimple Array<inT>,
+ |arrayFieldComplex Array<Map<String, Struct<key:bigInt>>>,
+ |mapFieldSimple MAP<iNt, StRing>,
+ |mapFieldComplex Map<Map<stRING, fLOAT>, Struct<key:bigInt>>,
+ |structFieldSimple StRuct<key:INt, Value:STrINg>,
+ |structFieldComplex StRuct<key:Array<String>, Value:struct<`value_(2)`:Array<date>>>
+ |)
+ |USING org.apache.spark.sql.sources.AllDataTypesScanSource
+ |OPTIONS (
+ | From '1',
+ | To '10'
+ |)
+ """.stripMargin)
}
sqlTest(
@@ -73,6 +175,96 @@ class TableScanSuite extends DataSourceTest {
"SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1",
(2 to 10).map(i => Row(i, i - 1)).toSeq)
+ test("Schema and all fields") {
+ val expectedSchema = StructType(
+ StructField("string$%Field", StringType, true) ::
+ StructField("binaryField", BinaryType, true) ::
+ StructField("booleanField", BooleanType, true) ::
+ StructField("ByteField", ByteType, true) ::
+ StructField("shortField", ShortType, true) ::
+ StructField("int_Field", IntegerType, true) ::
+ StructField("longField_:,<>=+/~^", LongType, true) ::
+ StructField("floatField", FloatType, true) ::
+ StructField("doubleField", DoubleType, true) ::
+ StructField("decimalField1", DecimalType.Unlimited, true) ::
+ StructField("decimalField2", DecimalType(9, 2), true) ::
+ StructField("dateField", DateType, true) ::
+ StructField("timestampField", TimestampType, true) ::
+ StructField("varcharField", StringType, true) ::
+ StructField("arrayFieldSimple", ArrayType(IntegerType), true) ::
+ StructField("arrayFieldComplex",
+ ArrayType(
+ MapType(StringType, StructType(StructField("key", LongType, true) :: Nil))), true) ::
+ StructField("mapFieldSimple", MapType(IntegerType, StringType), true) ::
+ StructField("mapFieldComplex",
+ MapType(
+ MapType(StringType, FloatType),
+ StructType(StructField("key", LongType, true) :: Nil)), true) ::
+ StructField("structFieldSimple",
+ StructType(
+ StructField("key", IntegerType, true) ::
+ StructField("Value", StringType, true) :: Nil), true) ::
+ StructField("structFieldComplex",
+ StructType(
+ StructField("key", ArrayType(StringType), true) ::
+ StructField("Value",
+ StructType(
+ StructField("value_(2)", ArrayType(DateType), true) :: Nil), true) :: Nil), true) ::
+ Nil
+ )
+
+ assert(expectedSchema == table("tableWithSchema").schema)
+
+ checkAnswer(
+ sql(
+ """SELECT
+ | `string$%Field`,
+ | cast(binaryField as string),
+ | booleanField,
+ | byteField,
+ | shortField,
+ | int_Field,
+ | `longField_:,<>=+/~^`,
+ | floatField,
+ | doubleField,
+ | decimalField1,
+ | decimalField2,
+ | dateField,
+ | timestampField,
+ | varcharField,
+ | arrayFieldSimple,
+ | arrayFieldComplex,
+ | mapFieldSimple,
+ | mapFieldComplex,
+ | structFieldSimple,
+ | structFieldComplex FROM tableWithSchema""".stripMargin),
+ tableWithSchemaExpected
+ )
+ }
+
+ sqlTest(
+ "SELECT count(*) FROM tableWithSchema",
+ 10)
+
+ sqlTest(
+ "SELECT `string$%Field` FROM tableWithSchema",
+ (1 to 10).map(i => Row(s"str_$i")).toSeq)
+
+ sqlTest(
+ "SELECT int_Field FROM tableWithSchema WHERE int_Field < 5",
+ (1 to 4).map(Row(_)).toSeq)
+
+ sqlTest(
+ "SELECT `longField_:,<>=+/~^` * 2 FROM tableWithSchema",
+ (1 to 10).map(i => Row(i * 2.toLong)).toSeq)
+
+ sqlTest(
+ "SELECT structFieldSimple.key, arrayFieldSimple[1] FROM tableWithSchema a where int_Field=1",
+ Seq(Seq(1, 2)))
+
+ sqlTest(
+ "SELECT structFieldComplex.Value.`value_(2)` FROM tableWithSchema",
+ (1 to 10).map(i => Row(Seq(new Date((i + 2) * 8640000)))).toSeq)
test("Caching") {
// Cached Query Execution
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 2c859894cf..c25288e000 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -20,12 +20,7 @@ package org.apache.spark.sql.hive
import java.io.IOException
import java.util.{List => JList}
-import org.apache.spark.sql.execution.SparkPlan
-
-import scala.util.parsing.combinator.RegexParsers
-
import org.apache.hadoop.util.ReflectionUtils
-
import org.apache.hadoop.hive.metastore.TableType
import org.apache.hadoop.hive.metastore.api.FieldSchema
import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition}
@@ -37,7 +32,6 @@ import org.apache.hadoop.hive.serde2.{Deserializer, SerDeException}
import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
import org.apache.spark.Logging
-import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.analysis.{Catalog, OverrideCatalog}
import org.apache.spark.sql.catalyst.expressions._
@@ -412,88 +406,6 @@ private[hive] case class InsertIntoHiveTable(
}
}
-/**
- * :: DeveloperApi ::
- * Provides conversions between Spark SQL data types and Hive Metastore types.
- */
-@DeveloperApi
-object HiveMetastoreTypes extends RegexParsers {
- protected lazy val primitiveType: Parser[DataType] =
- "string" ^^^ StringType |
- "float" ^^^ FloatType |
- "int" ^^^ IntegerType |
- "tinyint" ^^^ ByteType |
- "smallint" ^^^ ShortType |
- "double" ^^^ DoubleType |
- "bigint" ^^^ LongType |
- "binary" ^^^ BinaryType |
- "boolean" ^^^ BooleanType |
- fixedDecimalType | // Hive 0.13+ decimal with precision/scale
- "decimal" ^^^ DecimalType.Unlimited | // Hive 0.12 decimal with no precision/scale
- "date" ^^^ DateType |
- "timestamp" ^^^ TimestampType |
- "varchar\\((\\d+)\\)".r ^^^ StringType
-
- protected lazy val fixedDecimalType: Parser[DataType] =
- ("decimal" ~> "(" ~> "\\d+".r) ~ ("," ~> "\\d+".r <~ ")") ^^ {
- case precision ~ scale =>
- DecimalType(precision.toInt, scale.toInt)
- }
-
- protected lazy val arrayType: Parser[DataType] =
- "array" ~> "<" ~> dataType <~ ">" ^^ {
- case tpe => ArrayType(tpe)
- }
-
- protected lazy val mapType: Parser[DataType] =
- "map" ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ {
- case t1 ~ _ ~ t2 => MapType(t1, t2)
- }
-
- protected lazy val structField: Parser[StructField] =
- "[a-zA-Z0-9_]*".r ~ ":" ~ dataType ^^ {
- case name ~ _ ~ tpe => StructField(name, tpe, nullable = true)
- }
-
- protected lazy val structType: Parser[DataType] =
- "struct" ~> "<" ~> repsep(structField,",") <~ ">" ^^ {
- case fields => new StructType(fields)
- }
-
- protected lazy val dataType: Parser[DataType] =
- arrayType |
- mapType |
- structType |
- primitiveType
-
- def toDataType(metastoreType: String): DataType = parseAll(dataType, metastoreType) match {
- case Success(result, _) => result
- case failure: NoSuccess => sys.error(s"Unsupported dataType: $metastoreType")
- }
-
- def toMetastoreType(dt: DataType): String = dt match {
- case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>"
- case StructType(fields) =>
- s"struct<${fields.map(f => s"${f.name}:${toMetastoreType(f.dataType)}").mkString(",")}>"
- case MapType(keyType, valueType, _) =>
- s"map<${toMetastoreType(keyType)},${toMetastoreType(valueType)}>"
- case StringType => "string"
- case FloatType => "float"
- case IntegerType => "int"
- case ByteType => "tinyint"
- case ShortType => "smallint"
- case DoubleType => "double"
- case LongType => "bigint"
- case BinaryType => "binary"
- case BooleanType => "boolean"
- case DateType => "date"
- case d: DecimalType => HiveShim.decimalMetastoreString(d)
- case TimestampType => "timestamp"
- case NullType => "void"
- case udt: UserDefinedType[_] => toMetastoreType(udt.sqlType)
- }
-}
-
private[hive] case class MetastoreRelation
(databaseName: String, tableName: String, alias: Option[String])
(val table: TTable, val partitions: Seq[TPartition])
@@ -551,7 +463,7 @@ private[hive] case class MetastoreRelation
implicit class SchemaAttribute(f: FieldSchema) {
def toAttribute = AttributeReference(
f.getName,
- HiveMetastoreTypes.toDataType(f.getType),
+ sqlContext.ddlParser.parseType(f.getType),
// Since data can be dumped in randomly with no validation, everything is nullable.
nullable = true
)(qualifiers = Seq(alias.getOrElse(tableName)))
@@ -571,3 +483,27 @@ private[hive] case class MetastoreRelation
/** An attribute map for determining the ordinal for non-partition columns. */
val columnOrdinals = AttributeMap(attributes.zipWithIndex)
}
+
+object HiveMetastoreTypes {
+ def toMetastoreType(dt: DataType): String = dt match {
+ case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>"
+ case StructType(fields) =>
+ s"struct<${fields.map(f => s"${f.name}:${toMetastoreType(f.dataType)}").mkString(",")}>"
+ case MapType(keyType, valueType, _) =>
+ s"map<${toMetastoreType(keyType)},${toMetastoreType(valueType)}>"
+ case StringType => "string"
+ case FloatType => "float"
+ case IntegerType => "int"
+ case ByteType => "tinyint"
+ case ShortType => "smallint"
+ case DoubleType => "double"
+ case LongType => "bigint"
+ case BinaryType => "binary"
+ case BooleanType => "boolean"
+ case DateType => "date"
+ case d: DecimalType => HiveShim.decimalMetastoreString(d)
+ case TimestampType => "timestamp"
+ case NullType => "void"
+ case udt: UserDefinedType[_] => toMetastoreType(udt.sqlType)
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
index 86535f8dd4..041a36f129 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.hive
import org.scalatest.FunSuite
import org.apache.spark.sql.catalyst.types.StructType
+import org.apache.spark.sql.sources.DDLParser
import org.apache.spark.sql.test.ExamplePointUDT
class HiveMetastoreCatalogSuite extends FunSuite {
@@ -27,7 +28,9 @@ class HiveMetastoreCatalogSuite extends FunSuite {
test("struct field should accept underscore in sub-column name") {
val metastr = "struct<a: int, b_1: string, c: string>"
- val datatype = HiveMetastoreTypes.toDataType(metastr)
+ val ddlParser = new DDLParser
+
+ val datatype = ddlParser.parseType(metastr)
assert(datatype.isInstanceOf[StructType])
}