Diffstat (limited to 'sql/hive/src/main/scala/org/apache')
-rw-r--r--  sql/hive/src/main/scala/org/apache/hadoop/mapred/SparkHadoopWriter.scala     198
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala          287
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala 246
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala               966
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala       164
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/ScriptTransformation.scala  76
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala          243
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala             341
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala        356
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala             467
10 files changed, 3344 insertions(+), 0 deletions(-)
diff --git a/sql/hive/src/main/scala/org/apache/hadoop/mapred/SparkHadoopWriter.scala b/sql/hive/src/main/scala/org/apache/hadoop/mapred/SparkHadoopWriter.scala
new file mode 100644
index 0000000000..08d390e887
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/hadoop/mapred/SparkHadoopWriter.scala
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mapred
+
+import java.io.IOException
+import java.text.NumberFormat
+import java.util.Date
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.Writable
+
+import org.apache.spark.Logging
+import org.apache.spark.SerializableWritable
+
+import org.apache.hadoop.hive.ql.exec.{Utilities, FileSinkOperator}
+import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat}
+import org.apache.hadoop.hive.ql.plan.FileSinkDesc
+
+/**
+ * Internal helper class that saves an RDD using a Hive OutputFormat.
+ * It is based on [[SparkHadoopWriter]].
+ */
+protected[apache]
+class SparkHiveHadoopWriter(
+ @transient jobConf: JobConf,
+ fileSinkConf: FileSinkDesc)
+ extends Logging
+ with SparkHadoopMapRedUtil
+ with Serializable {
+
+ private val now = new Date()
+ private val conf = new SerializableWritable(jobConf)
+
+ private var jobID = 0
+ private var splitID = 0
+ private var attemptID = 0
+ private var jID: SerializableWritable[JobID] = null
+ private var taID: SerializableWritable[TaskAttemptID] = null
+
+ @transient private var writer: FileSinkOperator.RecordWriter = null
+ @transient private var format: HiveOutputFormat[AnyRef, Writable] = null
+ @transient private var committer: OutputCommitter = null
+ @transient private var jobContext: JobContext = null
+ @transient private var taskContext: TaskAttemptContext = null
+
+ def preSetup() {
+ setIDs(0, 0, 0)
+ setConfParams()
+
+ val jCtxt = getJobContext()
+ getOutputCommitter().setupJob(jCtxt)
+ }
+
+
+ def setup(jobid: Int, splitid: Int, attemptid: Int) {
+ setIDs(jobid, splitid, attemptid)
+ setConfParams()
+ }
+
+ def open() {
+ val numfmt = NumberFormat.getInstance()
+ numfmt.setMinimumIntegerDigits(5)
+ numfmt.setGroupingUsed(false)
+
+ val extension = Utilities.getFileExtension(
+ conf.value,
+ fileSinkConf.getCompressed,
+ getOutputFormat())
+
+ val outputName = "part-" + numfmt.format(splitID) + extension
+ val path = FileOutputFormat.getTaskOutputPath(conf.value, outputName)
+
+ getOutputCommitter().setupTask(getTaskContext())
+ writer = HiveFileFormatUtils.getHiveRecordWriter(
+ conf.value,
+ fileSinkConf.getTableInfo,
+ conf.value.getOutputValueClass.asInstanceOf[Class[Writable]],
+ fileSinkConf,
+ path,
+ null)
+ }
+
+ def write(value: Writable) {
+ if (writer != null) {
+ writer.write(value)
+ } else {
+ throw new IOException("Writer is null, open() has not been called")
+ }
+ }
+
+ def close() {
+ // Seems the boolean value passed into close does not matter.
+ writer.close(false)
+ }
+
+ def commit() {
+ val taCtxt = getTaskContext()
+ val cmtr = getOutputCommitter()
+ if (cmtr.needsTaskCommit(taCtxt)) {
+ try {
+ cmtr.commitTask(taCtxt)
+ logInfo (taID + ": Committed")
+ } catch {
+ case e: IOException => {
+ logError("Error committing the output of task: " + taID.value, e)
+ cmtr.abortTask(taCtxt)
+ throw e
+ }
+ }
+ } else {
+ logWarning ("No need to commit output of task: " + taID.value)
+ }
+ }
+
+ def commitJob() {
+ // always ? Or if cmtr.needsTaskCommit ?
+ val cmtr = getOutputCommitter()
+ cmtr.commitJob(getJobContext())
+ }
+
+ // ********* Private Functions *********
+
+ private def getOutputFormat(): HiveOutputFormat[AnyRef,Writable] = {
+ if (format == null) {
+ format = conf.value.getOutputFormat()
+ .asInstanceOf[HiveOutputFormat[AnyRef,Writable]]
+ }
+ format
+ }
+
+ private def getOutputCommitter(): OutputCommitter = {
+ if (committer == null) {
+ committer = conf.value.getOutputCommitter
+ }
+ committer
+ }
+
+ private def getJobContext(): JobContext = {
+ if (jobContext == null) {
+ jobContext = newJobContext(conf.value, jID.value)
+ }
+ jobContext
+ }
+
+ private def getTaskContext(): TaskAttemptContext = {
+ if (taskContext == null) {
+ taskContext = newTaskAttemptContext(conf.value, taID.value)
+ }
+ taskContext
+ }
+
+ private def setIDs(jobid: Int, splitid: Int, attemptid: Int) {
+ jobID = jobid
+ splitID = splitid
+ attemptID = attemptid
+
+ jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobid))
+ taID = new SerializableWritable[TaskAttemptID](
+ new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID))
+ }
+
+ private def setConfParams() {
+ conf.value.set("mapred.job.id", jID.value.toString)
+ conf.value.set("mapred.tip.id", taID.value.getTaskID.toString)
+ conf.value.set("mapred.task.id", taID.value.toString)
+ conf.value.setBoolean("mapred.task.is.map", true)
+ conf.value.setInt("mapred.task.partition", splitID)
+ }
+}
+
+object SparkHiveHadoopWriter {
+ def createPathFromString(path: String, conf: JobConf): Path = {
+ if (path == null) {
+ throw new IllegalArgumentException("Output path is null")
+ }
+ val outputPath = new Path(path)
+ val fs = outputPath.getFileSystem(conf)
+ if (outputPath == null || fs == null) {
+ throw new IllegalArgumentException("Incorrectly formatted output path")
+ }
+ outputPath.makeQualified(fs)
+ }
+}
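+
+// A minimal usage sketch: how a single task might drive SparkHiveHadoopWriter above. The
+// object name and the jobId/splitId/attemptNumber parameters are hypothetical; a real caller
+// would obtain those values from its task context.
+object SparkHiveHadoopWriterExample {
+  def writePartition(
+      writer: SparkHiveHadoopWriter,
+      jobId: Int,
+      splitId: Int,
+      attemptNumber: Int,
+      rows: Iterator[Writable]) {
+    writer.setup(jobId, splitId, attemptNumber)  // assign IDs and set the mapred.* conf params
+    writer.open()                                // create the Hive RecordWriter for this split
+    try {
+      rows.foreach(writer.write)                 // append each serialized row
+    } finally {
+      writer.close()                             // flush and close the underlying writer
+    }
+    writer.commit()                              // let the OutputCommitter commit this attempt
+  }
+}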
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
new file mode 100644
index 0000000000..4aad876cc0
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -0,0 +1,287 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+
+import java.io.{PrintStream, InputStreamReader, BufferedReader, File}
+import java.util.{ArrayList => JArrayList}
+import scala.language.implicitConversions
+
+import org.apache.spark.SparkContext
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.ql.session.SessionState
+import org.apache.hadoop.hive.ql.processors.{CommandProcessorResponse, CommandProcessorFactory}
+import org.apache.hadoop.hive.ql.processors.CommandProcessor
+import org.apache.hadoop.hive.ql.Driver
+import org.apache.spark.rdd.RDD
+
+import catalyst.analysis.{Analyzer, OverrideCatalog}
+import catalyst.expressions.GenericRow
+import catalyst.plans.logical.{BaseRelation, LogicalPlan, NativeCommand, ExplainCommand}
+import catalyst.types._
+
+import org.apache.spark.sql.execution._
+
+import scala.collection.JavaConversions._
+
+/**
+ * Starts up an instance of Hive where metadata is stored locally. An in-process metastore is
+ * created with data stored in ./metastore. Warehouse data is stored in ./warehouse.
+ */
+class LocalHiveContext(sc: SparkContext) extends HiveContext(sc) {
+
+ lazy val metastorePath = new File("metastore").getCanonicalPath
+ lazy val warehousePath: String = new File("warehouse").getCanonicalPath
+
+ /** Sets up the system initially or after a RESET command */
+ protected def configure() {
+ // TODO: refactor this so we can work with other databases.
+ runSqlHive(
+ s"set javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$metastorePath;create=true")
+ runSqlHive("set hive.metastore.warehouse.dir=" + warehousePath)
+ }
+
+ configure() // Must be called before initializing the catalog below.
+}
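+
+// A minimal sketch of using LocalHiveContext for local experimentation. The "local" master
+// string, the query text, and the object name are arbitrary examples; executePlan and
+// HiveQl.parseSql are the entry points defined in this patch.
+object LocalHiveContextExample {
+  def main(args: Array[String]) {
+    val sc = new SparkContext("local", "LocalHiveContextExample")
+    val hiveContext = new LocalHiveContext(sc)  // metastore/ and warehouse/ under the working dir
+    val plan = HiveQl.parseSql("SELECT key, value FROM src")
+    hiveContext.executePlan(plan).toRdd.collect().foreach(println)
+    sc.stop()
+  }
+}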
+
+/**
+ * An instance of the Spark SQL execution engine that integrates with data stored in Hive.
+ * Configuration for Hive is read from hive-site.xml on the classpath.
+ */
+class HiveContext(sc: SparkContext) extends SQLContext(sc) {
+ self =>
+
+ override def parseSql(sql: String): LogicalPlan = HiveQl.parseSql(sql)
+ override def executePlan(plan: LogicalPlan): this.QueryExecution =
+ new this.QueryExecution { val logical = plan }
+
+ // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur.
+ @transient
+ protected val outputBuffer = new java.io.OutputStream {
+ var pos: Int = 0
+ var buffer = new Array[Int](10240)
+ def write(i: Int): Unit = {
+ buffer(pos) = i
+ pos = (pos + 1) % buffer.size
+ }
+
+ override def toString = {
+ val (end, start) = buffer.splitAt(pos)
+ val input = new java.io.InputStream {
+ val iterator = (start ++ end).iterator
+
+ def read(): Int = if (iterator.hasNext) iterator.next else -1
+ }
+ val reader = new BufferedReader(new InputStreamReader(input))
+ val stringBuilder = new StringBuilder
+ var line = reader.readLine()
+ while(line != null) {
+ stringBuilder.append(line)
+ stringBuilder.append("\n")
+ line = reader.readLine()
+ }
+ stringBuilder.toString()
+ }
+ }
+
+ @transient protected[hive] lazy val hiveconf = new HiveConf(classOf[SessionState])
+ @transient protected[hive] lazy val sessionState = new SessionState(hiveconf)
+
+ sessionState.err = new PrintStream(outputBuffer, true, "UTF-8")
+ sessionState.out = new PrintStream(outputBuffer, true, "UTF-8")
+
+ /* A catalyst metadata catalog that points to the Hive Metastore. */
+ @transient
+ override lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog
+
+ /* An analyzer that uses the Hive metastore. */
+ @transient
+ override lazy val analyzer = new Analyzer(catalog, HiveFunctionRegistry, caseSensitive = false)
+
+ def tables: Seq[BaseRelation] = {
+ // TODO: Move this functionality to Catalog. Make client protected.
+ val allTables = catalog.client.getAllTables("default")
+ allTables.map(catalog.lookupRelation(None, _, None)).collect { case b: BaseRelation => b }
+ }
+
+ /**
+ * Runs the specified SQL query using Hive.
+ */
+ protected def runSqlHive(sql: String): Seq[String] = {
+ val maxResults = 100000
+ val results = runHive(sql, maxResults)
+ // It is very confusing when you only get back some of the results...
+ if (results.size == maxResults) sys.error("RESULTS POSSIBLY TRUNCATED")
+ results
+ }
+
+ // TODO: Move this.
+
+ SessionState.start(sessionState)
+
+ /**
+ * Execute the command using Hive and return the results as a sequence. Each element
+ * in the sequence is one row.
+ */
+ protected def runHive(cmd: String, maxRows: Int = 1000): Seq[String] = {
+ try {
+ val cmd_trimmed: String = cmd.trim()
+ val tokens: Array[String] = cmd_trimmed.split("\\s+")
+ val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim()
+ val proc: CommandProcessor = CommandProcessorFactory.get(tokens(0), hiveconf)
+
+ SessionState.start(sessionState)
+
+ if (proc.isInstanceOf[Driver]) {
+ val driver: Driver = proc.asInstanceOf[Driver]
+ driver.init()
+
+ val results = new JArrayList[String]
+ val response: CommandProcessorResponse = driver.run(cmd)
+ // Throw an exception if there is an error in query processing.
+ if (response.getResponseCode != 0) {
+ driver.destroy()
+ throw new QueryExecutionException(response.getErrorMessage)
+ }
+ driver.setMaxRows(maxRows)
+ driver.getResults(results)
+ driver.destroy()
+ results
+ } else {
+ sessionState.out.println(tokens(0) + " " + cmd_1)
+ Seq(proc.run(cmd_1).getResponseCode.toString)
+ }
+ } catch {
+ case e: Exception =>
+ logger.error(
+ s"""
+ |======================
+ |HIVE FAILURE OUTPUT
+ |======================
+ |${outputBuffer.toString}
+ |======================
+ |END HIVE FAILURE OUTPUT
+ |======================
+ """.stripMargin)
+ throw e
+ }
+ }
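+
+  // A minimal sketch of issuing a metadata command through runHive; DESCRIBE is an arbitrary
+  // example and this helper exists only for illustration. Each returned string is one row.
+  protected def describeTable(tableName: String): Seq[String] =
+    runHive(s"DESCRIBE $tableName", maxRows = 1000)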
+
+ @transient
+ val hivePlanner = new SparkPlanner with HiveStrategies {
+ val hiveContext = self
+
+ override val strategies: Seq[Strategy] = Seq(
+ TopK,
+ ColumnPrunings,
+ PartitionPrunings,
+ HiveTableScans,
+ DataSinks,
+ Scripts,
+ PartialAggregation,
+ SparkEquiInnerJoin,
+ BasicOperators,
+ CartesianProduct,
+ BroadcastNestedLoopJoin
+ )
+ }
+
+ @transient
+ override val planner = hivePlanner
+
+ @transient
+ protected lazy val emptyResult =
+ sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1)
+
+ /** Extends QueryExecution with hive specific features. */
+ abstract class QueryExecution extends super.QueryExecution {
+ // TODO: Create mixin for the analyzer instead of overriding things here.
+ override lazy val optimizedPlan =
+ optimizer(catalog.PreInsertionCasts(catalog.CreateTables(analyzed)))
+
+ // TODO: We are losing the schema here.
+ override lazy val toRdd: RDD[Row] =
+ analyzed match {
+ case NativeCommand(cmd) =>
+ val output = runSqlHive(cmd)
+
+ if (output.size == 0) {
+ emptyResult
+ } else {
+ val asRows = output.map(r => new GenericRow(r.split("\t").asInstanceOf[Array[Any]]))
+ sparkContext.parallelize(asRows, 1)
+ }
+ case _ =>
+ executedPlan.execute.map(_.copy())
+ }
+
+ protected val primitiveTypes =
+ Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType,
+ ShortType, DecimalType)
+
+ protected def toHiveString(a: (Any, DataType)): String = a match {
+ case (struct: Row, StructType(fields)) =>
+ struct.zip(fields).map {
+ case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}"""
+ }.mkString("{", ",", "}")
+ case (seq: Seq[_], ArrayType(typ)) =>
+ seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]")
+ case (map: Map[_,_], MapType(kType, vType)) =>
+ map.map {
+ case (key, value) =>
+ toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType))
+ }.toSeq.sorted.mkString("{", ",", "}")
+ case (null, _) => "NULL"
+ case (other, tpe) if primitiveTypes contains tpe => other.toString
+ }
+
+ /** Hive outputs fields of structs slightly differently than top level attributes. */
+ protected def toHiveStructString(a: (Any, DataType)): String = a match {
+ case (struct: Row, StructType(fields)) =>
+ struct.zip(fields).map {
+ case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}"""
+ }.mkString("{", ",", "}")
+ case (seq: Seq[_], ArrayType(typ)) =>
+ seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]")
+ case (map: Map[_,_], MapType(kType, vType)) =>
+ map.map {
+ case (key, value) =>
+ toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType))
+ }.toSeq.sorted.mkString("{", ",", "}")
+ case (null, _) => "null"
+ case (s: String, StringType) => "\"" + s + "\""
+ case (other, tpe) if primitiveTypes contains tpe => other.toString
+ }
+
+ /**
+ * Returns the result as a hive compatible sequence of strings. For native commands, the
+ * execution is simply passed back to Hive.
+ */
+ def stringResult(): Seq[String] = analyzed match {
+ case NativeCommand(cmd) => runSqlHive(cmd)
+ case ExplainCommand(plan) => new QueryExecution { val logical = plan }.toString.split("\n")
+ case query =>
+ val result: Seq[Seq[Any]] = toRdd.collect().toSeq
+ // We need the types so we can output struct field names
+ val types = analyzed.output.map(_.dataType)
+ // Reformat to match hive tab delimited output.
+ val asString = result.map(_.zip(types).map(toHiveString)).map(_.mkString("\t")).toSeq
+ asString
+ }
+ }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
new file mode 100644
index 0000000000..e4d50722ce
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+
+import scala.util.parsing.combinator.RegexParsers
+
+import org.apache.hadoop.hive.metastore.api.{FieldSchema, StorageDescriptor, SerDeInfo}
+import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition}
+import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table}
+import org.apache.hadoop.hive.ql.plan.TableDesc
+import org.apache.hadoop.hive.ql.session.SessionState
+import org.apache.hadoop.hive.serde2.Deserializer
+
+import catalyst.analysis.Catalog
+import catalyst.expressions._
+import catalyst.plans.logical
+import catalyst.plans.logical._
+import catalyst.rules._
+import catalyst.types._
+
+import scala.collection.JavaConversions._
+
+class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with Logging {
+ import HiveMetastoreTypes._
+
+ val client = Hive.get(hive.hiveconf)
+
+ def lookupRelation(
+ db: Option[String],
+ tableName: String,
+ alias: Option[String]): LogicalPlan = {
+ val databaseName = db.getOrElse(hive.sessionState.getCurrentDatabase())
+ val table = client.getTable(databaseName, tableName)
+ val partitions: Seq[Partition] =
+ if (table.isPartitioned) {
+ client.getPartitions(table)
+ } else {
+ Nil
+ }
+
+ // Since HiveQL is case insensitive for table names we make them all lowercase.
+ MetastoreRelation(
+ databaseName.toLowerCase,
+ tableName.toLowerCase,
+ alias)(table.getTTable, partitions.map(part => part.getTPartition))
+ }
+
+ def createTable(databaseName: String, tableName: String, schema: Seq[Attribute]) {
+ val table = new Table(databaseName, tableName)
+ val hiveSchema =
+ schema.map(attr => new FieldSchema(attr.name, toMetastoreType(attr.dataType), ""))
+ table.setFields(hiveSchema)
+
+ val sd = new StorageDescriptor()
+ table.getTTable.setSd(sd)
+ sd.setCols(hiveSchema)
+
+ // TODO: THESE ARE ALL DEFAULTS, WE NEED TO PARSE / UNDERSTAND the output specs.
+ sd.setCompressed(false)
+ sd.setParameters(Map[String, String]())
+ sd.setInputFormat("org.apache.hadoop.mapred.TextInputFormat")
+ sd.setOutputFormat("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")
+ val serDeInfo = new SerDeInfo()
+ serDeInfo.setName(tableName)
+ serDeInfo.setSerializationLib("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")
+ serDeInfo.setParameters(Map[String, String]())
+ sd.setSerdeInfo(serDeInfo)
+ client.createTable(table)
+ }
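+
+  // A minimal sketch of creating a simple two-column table through createTable above. The
+  // database/table names and schema are arbitrary examples; AttributeReference comes from
+  // catalyst.expressions, imported at the top of this file.
+  private def createExampleTable() {
+    val schema = Seq(
+      AttributeReference("key", IntegerType)(),
+      AttributeReference("value", StringType)())
+    createTable("default", "example_src", schema)
+  }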
+
+ /**
+ * Creates any tables required for query execution.
+ * For example, because of a CREATE TABLE X AS statement.
+ */
+ object CreateTables extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case InsertIntoCreatedTable(db, tableName, child) =>
+ val databaseName = db.getOrElse(SessionState.get.getCurrentDatabase())
+
+ createTable(databaseName, tableName, child.output)
+
+ InsertIntoTable(
+ lookupRelation(Some(databaseName), tableName, None).asInstanceOf[BaseRelation],
+ Map.empty,
+ child,
+ overwrite = false)
+ }
+ }
+
+ /**
+ * Casts input data to correct data types according to table definition before inserting into
+ * that table.
+ */
+ object PreInsertionCasts extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
+ // Wait until children are resolved
+ case p: LogicalPlan if !p.childrenResolved => p
+
+ case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) =>
+ val childOutputDataTypes = child.output.map(_.dataType)
+ // Only check attributes, not partitionKeys since they are always strings.
+ // TODO: Fully support inserting into partitioned tables.
+ val tableOutputDataTypes = table.attributes.map(_.dataType)
+
+ if (childOutputDataTypes == tableOutputDataTypes) {
+ p
+ } else {
+ // Only do the casting when child output data types differ from table output data types.
+ val castedChildOutput = child.output.zip(table.output).map {
+ case (input, table) if input.dataType != table.dataType =>
+ Alias(Cast(input, table.dataType), input.name)()
+ case (input, _) => input
+ }
+
+ p.copy(child = logical.Project(castedChildOutput, child))
+ }
+ }
+ }
+
+ /**
+ * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore.
+ * For now, if this functionality is desired, mix in the in-memory [[OverrideCatalog]].
+ */
+ override def registerTable(
+ databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = ???
+}
+
+object HiveMetastoreTypes extends RegexParsers {
+ protected lazy val primitiveType: Parser[DataType] =
+ "string" ^^^ StringType |
+ "float" ^^^ FloatType |
+ "int" ^^^ IntegerType |
+ "tinyint" ^^^ ShortType |
+ "double" ^^^ DoubleType |
+ "bigint" ^^^ LongType |
+ "binary" ^^^ BinaryType |
+ "boolean" ^^^ BooleanType |
+ "decimal" ^^^ DecimalType |
+ "varchar\\((\\d+)\\)".r ^^^ StringType
+
+ protected lazy val arrayType: Parser[DataType] =
+ "array" ~> "<" ~> dataType <~ ">" ^^ ArrayType
+
+ protected lazy val mapType: Parser[DataType] =
+ "map" ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ {
+ case t1 ~ _ ~ t2 => MapType(t1, t2)
+ }
+
+ protected lazy val structField: Parser[StructField] =
+ "[a-zA-Z0-9]*".r ~ ":" ~ dataType ^^ {
+ case name ~ _ ~ tpe => StructField(name, tpe, nullable = true)
+ }
+
+ protected lazy val structType: Parser[DataType] =
+ "struct" ~> "<" ~> repsep(structField,",") <~ ">" ^^ StructType
+
+ protected lazy val dataType: Parser[DataType] =
+ arrayType |
+ mapType |
+ structType |
+ primitiveType
+
+ def toDataType(metastoreType: String): DataType = parseAll(dataType, metastoreType) match {
+ case Success(result, _) => result
+ case failure: NoSuccess => sys.error(s"Unsupported dataType: $metastoreType")
+ }
+
+ def toMetastoreType(dt: DataType): String = dt match {
+ case ArrayType(elementType) => s"array<${toMetastoreType(elementType)}>"
+ case StructType(fields) =>
+ s"struct<${fields.map(f => s"${f.name}:${toMetastoreType(f.dataType)}").mkString(",")}>"
+ case MapType(keyType, valueType) =>
+ s"map<${toMetastoreType(keyType)},${toMetastoreType(valueType)}>"
+ case StringType => "string"
+ case FloatType => "float"
+ case IntegerType => "int"
+ case ShortType => "tinyint"
+ case DoubleType => "double"
+ case LongType => "bigint"
+ case BinaryType => "binary"
+ case BooleanType => "boolean"
+ case DecimalType => "decimal"
+ }
+}
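+
+// A minimal sketch of round-tripping a nested Hive type string through the parser above; the
+// object name and the example type string are arbitrary.
+object HiveMetastoreTypesExample {
+  def roundTrip() {
+    val parsed = HiveMetastoreTypes.toDataType("array<map<string,int>>")
+    assert(parsed == ArrayType(MapType(StringType, IntegerType)))
+    assert(HiveMetastoreTypes.toMetastoreType(parsed) == "array<map<string,int>>")
+  }
+}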
+
+case class MetastoreRelation(databaseName: String, tableName: String, alias: Option[String])
+ (val table: TTable, val partitions: Seq[TPartition])
+ extends BaseRelation {
+ // TODO: Can we use org.apache.hadoop.hive.ql.metadata.Table as the type of table and
+ // use org.apache.hadoop.hive.ql.metadata.Partition as the type of elements of partitions.
+ // Right now, using org.apache.hadoop.hive.ql.metadata.Table and
+ // org.apache.hadoop.hive.ql.metadata.Partition will cause a NotSerializableException
+ // which indicates the SerDe we used is not Serializable.
+
+ def hiveQlTable = new Table(table)
+
+ def hiveQlPartitions = partitions.map { p =>
+ new Partition(hiveQlTable, p)
+ }
+
+ override def isPartitioned = hiveQlTable.isPartitioned
+
+ val tableDesc = new TableDesc(
+ Class.forName(hiveQlTable.getSerializationLib).asInstanceOf[Class[Deserializer]],
+ hiveQlTable.getInputFormatClass,
+ // The class of table should be org.apache.hadoop.hive.ql.metadata.Table because
+ // getOutputFormatClass will use HiveFileFormatUtils.getOutputFormatSubstitute to
+ // substitute some output formats, e.g. substituting SequenceFileOutputFormat to
+ // HiveSequenceFileOutputFormat.
+ hiveQlTable.getOutputFormatClass,
+ hiveQlTable.getMetadata
+ )
+
+ implicit class SchemaAttribute(f: FieldSchema) {
+ def toAttribute = AttributeReference(
+ f.getName,
+ HiveMetastoreTypes.toDataType(f.getType),
+ // Since data can be dumped in randomly with no validation, everything is nullable.
+ nullable = true
+ )(qualifiers = tableName +: alias.toSeq)
+ }
+
+ // Must be a stable value since new attributes are born here.
+ val partitionKeys = hiveQlTable.getPartitionKeys.map(_.toAttribute)
+
+ /** Non-partitionKey attributes */
+ val attributes = table.getSd.getCols.map(_.toAttribute)
+
+ val output = attributes ++ partitionKeys
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
new file mode 100644
index 0000000000..4f33a293c3
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -0,0 +1,966 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+
+import scala.collection.JavaConversions._
+
+import org.apache.hadoop.hive.ql.lib.Node
+import org.apache.hadoop.hive.ql.parse._
+import org.apache.hadoop.hive.ql.plan.PlanUtils
+
+import catalyst.analysis._
+import catalyst.expressions._
+import catalyst.plans._
+import catalyst.plans.logical
+import catalyst.plans.logical._
+import catalyst.types._
+
+/**
+ * Used when we need to start parsing the AST before deciding that we are going to pass the command
+ * back for Hive to execute natively. Will be replaced with a native command that contains the
+ * cmd string.
+ */
+case object NativePlaceholder extends Command
+
+case class DfsCommand(cmd: String) extends Command
+
+case class ShellCommand(cmd: String) extends Command
+
+case class SourceCommand(filePath: String) extends Command
+
+case class AddJar(jarPath: String) extends Command
+
+case class AddFile(filePath: String) extends Command
+
+/** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */
+object HiveQl {
+ protected val nativeCommands = Seq(
+ "TOK_DESCFUNCTION",
+ "TOK_DESCTABLE",
+ "TOK_DESCDATABASE",
+ "TOK_SHOW_TABLESTATUS",
+ "TOK_SHOWDATABASES",
+ "TOK_SHOWFUNCTIONS",
+ "TOK_SHOWINDEXES",
+ "TOK_SHOWINDEXES",
+ "TOK_SHOWPARTITIONS",
+ "TOK_SHOWTABLES",
+
+ "TOK_LOCKTABLE",
+ "TOK_SHOWLOCKS",
+ "TOK_UNLOCKTABLE",
+
+ "TOK_CREATEROLE",
+ "TOK_DROPROLE",
+ "TOK_GRANT",
+ "TOK_GRANT_ROLE",
+ "TOK_REVOKE",
+ "TOK_SHOW_GRANT",
+ "TOK_SHOW_ROLE_GRANT",
+
+ "TOK_CREATEFUNCTION",
+ "TOK_DROPFUNCTION",
+
+ "TOK_ANALYZE",
+ "TOK_ALTERDATABASE_PROPERTIES",
+ "TOK_ALTERINDEX_PROPERTIES",
+ "TOK_ALTERINDEX_REBUILD",
+ "TOK_ALTERTABLE_ADDCOLS",
+ "TOK_ALTERTABLE_ADDPARTS",
+ "TOK_ALTERTABLE_ALTERPARTS",
+ "TOK_ALTERTABLE_ARCHIVE",
+ "TOK_ALTERTABLE_CLUSTER_SORT",
+ "TOK_ALTERTABLE_DROPPARTS",
+ "TOK_ALTERTABLE_PARTITION",
+ "TOK_ALTERTABLE_PROPERTIES",
+ "TOK_ALTERTABLE_RENAME",
+ "TOK_ALTERTABLE_RENAMECOL",
+ "TOK_ALTERTABLE_REPLACECOLS",
+ "TOK_ALTERTABLE_SKEWED",
+ "TOK_ALTERTABLE_TOUCH",
+ "TOK_ALTERTABLE_UNARCHIVE",
+ "TOK_ANALYZE",
+ "TOK_CREATEDATABASE",
+ "TOK_CREATEFUNCTION",
+ "TOK_CREATEINDEX",
+ "TOK_DROPDATABASE",
+ "TOK_DROPINDEX",
+ "TOK_DROPTABLE",
+ "TOK_MSCK",
+
+ // TODO(marmbrus): Figure out how views are expanded by Hive, as we might need to handle this.
+ "TOK_ALTERVIEW_ADDPARTS",
+ "TOK_ALTERVIEW_AS",
+ "TOK_ALTERVIEW_DROPPARTS",
+ "TOK_ALTERVIEW_PROPERTIES",
+ "TOK_ALTERVIEW_RENAME",
+ "TOK_CREATEVIEW",
+ "TOK_DROPVIEW",
+
+ "TOK_EXPORT",
+ "TOK_IMPORT",
+ "TOK_LOAD",
+
+ "TOK_SWITCHDATABASE"
+ )
+
+ /**
+ * A set of implicit transformations that allow Hive ASTNodes to be rewritten by transformations
+ * similar to [[catalyst.trees.TreeNode]].
+ *
+ * Note that this should be considered very experimental and is not intended as a replacement
+ * for TreeNode. In particular, ASTNodes are not immutable and do not appear to have clean
+ * copy semantics. Therefore, users of this class should take care when
+ * copying/modifying trees that might be used elsewhere.
+ */
+ implicit class TransformableNode(n: ASTNode) {
+ /**
+ * Returns a copy of this node where `rule` has been recursively applied to it and all of its
+ * children. When `rule` does not apply to a given node it is left unchanged.
+ * @param rule the function used to transform this node and its children
+ */
+ def transform(rule: PartialFunction[ASTNode, ASTNode]): ASTNode = {
+ try {
+ val afterRule = rule.applyOrElse(n, identity[ASTNode])
+ afterRule.withChildren(
+ nilIfEmpty(afterRule.getChildren)
+ .asInstanceOf[Seq[ASTNode]]
+ .map(ast => Option(ast).map(_.transform(rule)).orNull))
+ } catch {
+ case e: Exception =>
+ println(dumpTree(n))
+ throw e
+ }
+ }
+
+ /**
+ * Returns a scala.Seq equivalent to [s] or Nil if [s] is null.
+ */
+ private def nilIfEmpty[A](s: java.util.List[A]): Seq[A] =
+ Option(s).map(_.toSeq).getOrElse(Nil)
+
+ /**
+ * Returns this ASTNode with the text changed to `newText`.
+ */
+ def withText(newText: String): ASTNode = {
+ n.token.asInstanceOf[org.antlr.runtime.CommonToken].setText(newText)
+ n
+ }
+
+ /**
+ * Returns this ASTNode with the children changed to `newChildren`.
+ */
+ def withChildren(newChildren: Seq[ASTNode]): ASTNode = {
+ (1 to n.getChildCount).foreach(_ => n.deleteChild(0))
+ n.addChildren(newChildren)
+ n
+ }
+
+ /**
+ * Throws an error if this is not equal to other.
+ *
+ * Right now this function only checks the name, type, text and children of the node
+ * for equality.
+ */
+ def checkEquals(other: ASTNode) {
+ def check(field: String, f: ASTNode => Any) = if (f(n) != f(other)) {
+ sys.error(s"$field does not match for trees. " +
+ s"'${f(n)}' != '${f(other)}' left: ${dumpTree(n)}, right: ${dumpTree(other)}")
+ }
+ check("name", _.getName)
+ check("type", _.getType)
+ check("text", _.getText)
+ check("numChildren", n => nilIfEmpty(n.getChildren).size)
+
+ val leftChildren = nilIfEmpty(n.getChildren).asInstanceOf[Seq[ASTNode]]
+ val rightChildren = nilIfEmpty(other.getChildren).asInstanceOf[Seq[ASTNode]]
+ leftChildren zip rightChildren foreach {
+ case (l,r) => l checkEquals r
+ }
+ }
+ }
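+
+  // A minimal sketch of the implicit transform above: rewrite every node whose text is "src"
+  // to "src_copy". The rule and the method name are purely illustrative.
+  private def renameSrcExample(tree: ASTNode): ASTNode =
+    tree transform {
+      case t if t.getText == "src" => t.withText("src_copy")
+    }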
+
+ class ParseException(sql: String, cause: Throwable)
+ extends Exception(s"Failed to parse: $sql", cause)
+
+ /**
+ * Returns the AST for the given SQL string.
+ */
+ def getAst(sql: String): ASTNode = ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql))
+
+ /** Returns a LogicalPlan for a given HiveQL string. */
+ def parseSql(sql: String): LogicalPlan = {
+ try {
+ if (sql.toLowerCase.startsWith("set")) {
+ NativeCommand(sql)
+ } else if (sql.toLowerCase.startsWith("add jar")) {
+ AddJar(sql.drop(8))
+ } else if (sql.toLowerCase.startsWith("add file")) {
+ AddFile(sql.drop(9))
+ } else if (sql.startsWith("dfs")) {
+ DfsCommand(sql)
+ } else if (sql.startsWith("source")) {
+ SourceCommand(sql.split(" ").toSeq match { case Seq("source", filePath) => filePath })
+ } else if (sql.startsWith("!")) {
+ ShellCommand(sql.drop(1))
+ } else {
+ val tree = getAst(sql)
+
+ if (nativeCommands contains tree.getText) {
+ NativeCommand(sql)
+ } else {
+ nodeToPlan(tree) match {
+ case NativePlaceholder => NativeCommand(sql)
+ case other => other
+ }
+ }
+ }
+ } catch {
+ case e: Exception => throw new ParseException(sql, e)
+ }
+ }
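+
+  // A minimal sketch of how parseSql dispatches different inputs; the queries are arbitrary
+  // examples and the method exists only for illustration.
+  private def parseSqlExamples() {
+    // Statements Hive should run natively come back wrapped in NativeCommand.
+    assert(parseSql("SHOW TABLES").isInstanceOf[NativeCommand])
+    assert(parseSql("set hive.map.aggr=true") == NativeCommand("set hive.map.aggr=true"))
+    // Ordinary queries are converted into a catalyst logical plan.
+    assert(!parseSql("SELECT key FROM src").isInstanceOf[NativeCommand])
+  }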
+
+ def parseDdl(ddl: String): Seq[Attribute] = {
+ val tree =
+ try {
+ ParseUtils.findRootNonNullToken(
+ (new ParseDriver).parse(ddl, null /* no context required for parsing alone */))
+ } catch {
+ case pe: org.apache.hadoop.hive.ql.parse.ParseException =>
+ throw new RuntimeException(s"Failed to parse ddl: '$ddl'", pe)
+ }
+ assert(tree.asInstanceOf[ASTNode].getText == "TOK_CREATETABLE", "Only CREATE TABLE supported.")
+ val tableOps = tree.getChildren
+ val colList =
+ tableOps
+ .find(_.asInstanceOf[ASTNode].getText == "TOK_TABCOLLIST")
+ .getOrElse(sys.error("No columnList!")).getChildren
+
+ colList.map(nodeToAttribute)
+ }
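+
+  // A minimal sketch of parseDdl above: a simple CREATE TABLE statement yields one attribute
+  // per column. The statement is an arbitrary example.
+  private def parseDdlExample(): Seq[Attribute] =
+    parseDdl("CREATE TABLE example_people (name STRING, age INT)")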
+
+ /** Extractor for matching Hive's AST Tokens. */
+ object Token {
+ /** @return matches of the form (tokenName, children). */
+ def unapply(t: Any): Option[(String, Seq[ASTNode])] = t match {
+ case t: ASTNode =>
+ Some((t.getText,
+ Option(t.getChildren).map(_.toList).getOrElse(Nil).asInstanceOf[Seq[ASTNode]]))
+ case _ => None
+ }
+ }
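+
+  // A minimal sketch of the Token extractor above, used the same way by the parse rules that
+  // follow; the helper name is purely illustrative.
+  private def tableNameOf(node: ASTNode): Option[String] = node match {
+    case Token("TOK_TABNAME", Token(name, Nil) :: Nil) => Some(cleanIdentifier(name))
+    case _ => None
+  }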
+
+ protected def getClauses(clauseNames: Seq[String], nodeList: Seq[ASTNode]): Seq[Option[Node]] = {
+ var remainingNodes = nodeList
+ val clauses = clauseNames.map { clauseName =>
+ val (matches, nonMatches) = remainingNodes.partition(_.getText.toUpperCase == clauseName)
+ remainingNodes = nonMatches ++ (if (matches.nonEmpty) matches.tail else Nil)
+ matches.headOption
+ }
+
+ assert(remainingNodes.isEmpty,
+ s"Unhandled clauses: ${remainingNodes.map(dumpTree(_)).mkString("\n")}")
+ clauses
+ }
+
+ def getClause(clauseName: String, nodeList: Seq[Node]) =
+ getClauseOption(clauseName, nodeList).getOrElse(sys.error(
+ s"Expected clause $clauseName missing from ${nodeList.map(dumpTree(_)).mkString("\n")}"))
+
+ def getClauseOption(clauseName: String, nodeList: Seq[Node]): Option[Node] = {
+ nodeList.filter { case ast: ASTNode => ast.getText == clauseName } match {
+ case Seq(oneMatch) => Some(oneMatch)
+ case Seq() => None
+ case _ => sys.error(s"Found multiple instances of clause $clauseName")
+ }
+ }
+
+ protected def nodeToAttribute(node: Node): Attribute = node match {
+ case Token("TOK_TABCOL", Token(colName, Nil) :: dataType :: Nil) =>
+ AttributeReference(colName, nodeToDataType(dataType), true)()
+
+ case a: ASTNode =>
+ throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ")
+ }
+
+ protected def nodeToDataType(node: Node): DataType = node match {
+ case Token("TOK_BIGINT", Nil) => IntegerType
+ case Token("TOK_INT", Nil) => IntegerType
+ case Token("TOK_TINYINT", Nil) => IntegerType
+ case Token("TOK_SMALLINT", Nil) => IntegerType
+ case Token("TOK_BOOLEAN", Nil) => BooleanType
+ case Token("TOK_STRING", Nil) => StringType
+ case Token("TOK_FLOAT", Nil) => FloatType
+ case Token("TOK_DOUBLE", Nil) => FloatType
+ case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType))
+ case Token("TOK_STRUCT",
+ Token("TOK_TABCOLLIST", fields) :: Nil) =>
+ StructType(fields.map(nodeToStructField))
+ case Token("TOK_MAP",
+ keyType ::
+ valueType :: Nil) =>
+ MapType(nodeToDataType(keyType), nodeToDataType(valueType))
+ case a: ASTNode =>
+ throw new NotImplementedError(s"No parse rules for DataType:\n ${dumpTree(a).toString} ")
+ }
+
+ protected def nodeToStructField(node: Node): StructField = node match {
+ case Token("TOK_TABCOL",
+ Token(fieldName, Nil) ::
+ dataType :: Nil) =>
+ StructField(fieldName, nodeToDataType(dataType), nullable = true)
+ case Token("TOK_TABCOL",
+ Token(fieldName, Nil) ::
+ dataType ::
+ _ /* comment */:: Nil) =>
+ StructField(fieldName, nodeToDataType(dataType), nullable = true)
+ case a: ASTNode =>
+ throw new NotImplementedError(s"No parse rules for StructField:\n ${dumpTree(a).toString} ")
+ }
+
+ protected def nameExpressions(exprs: Seq[Expression]): Seq[NamedExpression] = {
+ exprs.zipWithIndex.map {
+ case (ne: NamedExpression, _) => ne
+ case (e, i) => Alias(e, s"c_$i")()
+ }
+ }
+
+ protected def nodeToPlan(node: Node): LogicalPlan = node match {
+ // Just fake explain for any of the native commands.
+ case Token("TOK_EXPLAIN", explainArgs) if nativeCommands contains explainArgs.head.getText =>
+ NoRelation
+ case Token("TOK_EXPLAIN", explainArgs) =>
+ // Ignore FORMATTED if present.
+ val Some(query) :: _ :: _ :: Nil =
+ getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs)
+ // TODO: support EXTENDED?
+ ExplainCommand(nodeToPlan(query))
+
+ case Token("TOK_CREATETABLE", children)
+ if children.collect { case t@Token("TOK_QUERY", _) => t }.nonEmpty =>
+ // TODO: Parse other clauses.
+ // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL
+ val (
+ Some(tableNameParts) ::
+ _ /* likeTable */ ::
+ Some(query) +:
+ notImplemented) =
+ getClauses(
+ Seq(
+ "TOK_TABNAME",
+ "TOK_LIKETABLE",
+ "TOK_QUERY",
+ "TOK_IFNOTEXISTS",
+ "TOK_TABLECOMMENT",
+ "TOK_TABCOLLIST",
+ "TOK_TABLEPARTCOLS", // Partitioned by
+ "TOK_TABLEBUCKETS", // Clustered by
+ "TOK_TABLESKEWED", // Skewed by
+ "TOK_TABLEROWFORMAT",
+ "TOK_TABLESERIALIZER",
+ "TOK_FILEFORMAT_GENERIC", // For file formats not natively supported by Hive.
+ "TOK_TBLSEQUENCEFILE", // Stored as SequenceFile
+ "TOK_TBLTEXTFILE", // Stored as TextFile
+ "TOK_TBLRCFILE", // Stored as RCFile
+ "TOK_TBLORCFILE", // Stored as ORC File
+ "TOK_TABLEFILEFORMAT", // User-provided InputFormat and OutputFormat
+ "TOK_STORAGEHANDLER", // Storage handler
+ "TOK_TABLELOCATION",
+ "TOK_TABLEPROPERTIES"),
+ children)
+ if (notImplemented.exists(token => !token.isEmpty)) {
+ throw new NotImplementedError(
+ s"Unhandled clauses: ${notImplemented.flatten.map(dumpTree(_)).mkString("\n")}")
+ }
+
+ val (db, tableName) =
+ tableNameParts.getChildren.map{ case Token(part, Nil) => cleanIdentifier(part)} match {
+ case Seq(tableOnly) => (None, tableOnly)
+ case Seq(databaseName, table) => (Some(databaseName), table)
+ }
+ InsertIntoCreatedTable(db, tableName, nodeToPlan(query))
+
+ // If it's not a "CREATE TABLE AS" like the above, just pass it back to Hive as a native command.
+ case Token("TOK_CREATETABLE", _) => NativePlaceholder
+
+ case Token("TOK_QUERY",
+ Token("TOK_FROM", fromClause :: Nil) ::
+ insertClauses) =>
+
+ // Return one query for each insert clause.
+ val queries = insertClauses.map { case Token("TOK_INSERT", singleInsert) =>
+ val (
+ intoClause ::
+ destClause ::
+ selectClause ::
+ selectDistinctClause ::
+ whereClause ::
+ groupByClause ::
+ orderByClause ::
+ sortByClause ::
+ clusterByClause ::
+ distributeByClause ::
+ limitClause ::
+ lateralViewClause :: Nil) = {
+ getClauses(
+ Seq(
+ "TOK_INSERT_INTO",
+ "TOK_DESTINATION",
+ "TOK_SELECT",
+ "TOK_SELECTDI",
+ "TOK_WHERE",
+ "TOK_GROUPBY",
+ "TOK_ORDERBY",
+ "TOK_SORTBY",
+ "TOK_CLUSTERBY",
+ "TOK_DISTRIBUTEBY",
+ "TOK_LIMIT",
+ "TOK_LATERAL_VIEW"),
+ singleInsert)
+ }
+
+ val relations = nodeToRelation(fromClause)
+ val withWhere = whereClause.map { whereNode =>
+ val Seq(whereExpr) = whereNode.getChildren.toSeq
+ Filter(nodeToExpr(whereExpr), relations)
+ }.getOrElse(relations)
+
+ val select =
+ (selectClause orElse selectDistinctClause).getOrElse(sys.error("No select clause."))
+
+ // Script transformations are expressed as a select clause with a single expression of type
+ // TOK_TRANSFORM
+ val transformation = select.getChildren.head match {
+ case Token("TOK_SELEXPR",
+ Token("TOK_TRANSFORM",
+ Token("TOK_EXPLIST", inputExprs) ::
+ Token("TOK_SERDE", Nil) ::
+ Token("TOK_RECORDWRITER", writerClause) ::
+ // TODO: Need to support other types of (in/out)put
+ Token(script, Nil) ::
+ Token("TOK_SERDE", serdeClause) ::
+ Token("TOK_RECORDREADER", readerClause) ::
+ outputClause :: Nil) :: Nil) =>
+
+ val output = outputClause match {
+ case Token("TOK_ALIASLIST", aliases) =>
+ aliases.map { case Token(name, Nil) => AttributeReference(name, StringType)() }
+ case Token("TOK_TABCOLLIST", attributes) =>
+ attributes.map { case Token("TOK_TABCOL", Token(name, Nil) :: dataType :: Nil) =>
+ AttributeReference(name, nodeToDataType(dataType))() }
+ }
+ val unescapedScript = BaseSemanticAnalyzer.unescapeSQLString(script)
+
+ Some(
+ logical.ScriptTransformation(
+ inputExprs.map(nodeToExpr),
+ unescapedScript,
+ output,
+ withWhere))
+ case _ => None
+ }
+
+ val withLateralView = lateralViewClause.map { lv =>
+ val Token("TOK_SELECT",
+ Token("TOK_SELEXPR", clauses) :: Nil) = lv.getChildren.head
+
+ val alias =
+ getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText
+
+ Generate(
+ nodesToGenerator(clauses),
+ join = true,
+ outer = false,
+ Some(alias.toLowerCase),
+ withWhere)
+ }.getOrElse(withWhere)
+
+
+ // The projection of the query can either be a normal projection, an aggregation
+ // (if there is a group by) or a script transformation.
+ val withProject = transformation.getOrElse {
+ // Not a transformation so must be either project or aggregation.
+ val selectExpressions = nameExpressions(select.getChildren.flatMap(selExprNodeToExpr))
+
+ groupByClause match {
+ case Some(groupBy) =>
+ Aggregate(groupBy.getChildren.map(nodeToExpr), selectExpressions, withLateralView)
+ case None =>
+ Project(selectExpressions, withLateralView)
+ }
+ }
+
+ val withDistinct =
+ if (selectDistinctClause.isDefined) Distinct(withProject) else withProject
+
+ val withSort =
+ (orderByClause, sortByClause, distributeByClause, clusterByClause) match {
+ case (Some(totalOrdering), None, None, None) =>
+ Sort(totalOrdering.getChildren.map(nodeToSortOrder), withDistinct)
+ case (None, Some(perPartitionOrdering), None, None) =>
+ SortPartitions(perPartitionOrdering.getChildren.map(nodeToSortOrder), withDistinct)
+ case (None, None, Some(partitionExprs), None) =>
+ Repartition(partitionExprs.getChildren.map(nodeToExpr), withDistinct)
+ case (None, Some(perPartitionOrdering), Some(partitionExprs), None) =>
+ SortPartitions(perPartitionOrdering.getChildren.map(nodeToSortOrder),
+ Repartition(partitionExprs.getChildren.map(nodeToExpr), withDistinct))
+ case (None, None, None, Some(clusterExprs)) =>
+ SortPartitions(clusterExprs.getChildren.map(nodeToExpr).map(SortOrder(_, Ascending)),
+ Repartition(clusterExprs.getChildren.map(nodeToExpr), withDistinct))
+ case (None, None, None, None) => withDistinct
+ case _ => sys.error("Unsupported set of ordering / distribution clauses.")
+ }
+
+ val withLimit =
+ limitClause.map(l => nodeToExpr(l.getChildren.head))
+ .map(StopAfter(_, withSort))
+ .getOrElse(withSort)
+
+ // TOK_INSERT_INTO means to add files to the table.
+ // TOK_DESTINATION means to overwrite the table.
+ val resultDestination =
+ (intoClause orElse destClause).getOrElse(sys.error("No destination found."))
+ val overwrite = intoClause.isEmpty
+ nodeToDest(
+ resultDestination,
+ withLimit,
+ overwrite)
+ }
+
+ // If there are multiple INSERTs, just UNION them together into one query.
+ queries.reduceLeft(Union)
+
+ case Token("TOK_UNION", left :: right :: Nil) => Union(nodeToPlan(left), nodeToPlan(right))
+
+ case a: ASTNode =>
+ throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ")
+ }
+
+ val allJoinTokens = "(TOK_.*JOIN)".r
+ val laterViewToken = "TOK_LATERAL_VIEW(.*)".r
+ def nodeToRelation(node: Node): LogicalPlan = node match {
+ case Token("TOK_SUBQUERY",
+ query :: Token(alias, Nil) :: Nil) =>
+ Subquery(alias, nodeToPlan(query))
+
+ case Token(laterViewToken(isOuter), selectClause :: relationClause :: Nil) =>
+ val Token("TOK_SELECT",
+ Token("TOK_SELEXPR", clauses) :: Nil) = selectClause
+
+ val alias = getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText
+
+ Generate(
+ nodesToGenerator(clauses),
+ join = true,
+ outer = isOuter.nonEmpty,
+ Some(alias.toLowerCase),
+ nodeToRelation(relationClause))
+
+ /* All relations, possibly with aliases or sampling clauses. */
+ case Token("TOK_TABREF", clauses) =>
+ // If the last clause is not a token then it's the alias of the table.
+ val (nonAliasClauses, aliasClause) =
+ if (clauses.last.getText.startsWith("TOK")) {
+ (clauses, None)
+ } else {
+ (clauses.dropRight(1), Some(clauses.last))
+ }
+
+ val (Some(tableNameParts) ::
+ splitSampleClause ::
+ bucketSampleClause :: Nil) = {
+ getClauses(Seq("TOK_TABNAME", "TOK_TABLESPLITSAMPLE", "TOK_TABLEBUCKETSAMPLE"),
+ nonAliasClauses)
+ }
+
+ val (db, tableName) =
+ tableNameParts.getChildren.map{ case Token(part, Nil) => cleanIdentifier(part)} match {
+ case Seq(tableOnly) => (None, tableOnly)
+ case Seq(databaseName, table) => (Some(databaseName), table)
+ }
+ val alias = aliasClause.map { case Token(a, Nil) => cleanIdentifier(a) }
+ val relation = UnresolvedRelation(db, tableName, alias)
+
+ // Apply sampling if requested.
+ (bucketSampleClause orElse splitSampleClause).map {
+ case Token("TOK_TABLESPLITSAMPLE",
+ Token("TOK_ROWCOUNT", Nil) ::
+ Token(count, Nil) :: Nil) =>
+ StopAfter(Literal(count.toInt), relation)
+ case Token("TOK_TABLESPLITSAMPLE",
+ Token("TOK_PERCENT", Nil) ::
+ Token(fraction, Nil) :: Nil) =>
+ Sample(fraction.toDouble, withReplacement = false, (math.random * 1000).toInt, relation)
+ case Token("TOK_TABLEBUCKETSAMPLE",
+ Token(numerator, Nil) ::
+ Token(denominator, Nil) :: Nil) =>
+ val fraction = numerator.toDouble / denominator.toDouble
+ Sample(fraction, withReplacement = false, (math.random * 1000).toInt, relation)
+ case a: ASTNode =>
+ throw new NotImplementedError(
+ s"""No parse rules for sampling clause: ${a.getType}, text: ${a.getText} :
+ |${dumpTree(a).toString}" +
+ """.stripMargin)
+ }.getOrElse(relation)
+
+ case Token("TOK_UNIQUEJOIN", joinArgs) =>
+ val tableOrdinals =
+ joinArgs.zipWithIndex.filter {
+ case (arg, i) => arg.getText == "TOK_TABREF"
+ }.map(_._2)
+
+ val isPreserved = tableOrdinals.map(i => (i - 1 < 0) || joinArgs(i - 1).getText == "PRESERVE")
+ val tables = tableOrdinals.map(i => nodeToRelation(joinArgs(i)))
+ val joinExpressions = tableOrdinals.map(i => joinArgs(i + 1).getChildren.map(nodeToExpr))
+
+ val joinConditions = joinExpressions.sliding(2).map {
+ case Seq(c1, c2) =>
+ val predicates = (c1, c2).zipped.map { case (e1, e2) => Equals(e1, e2): Expression }
+ predicates.reduceLeft(And)
+ }.toBuffer
+
+ val joinType = isPreserved.sliding(2).map {
+ case Seq(true, true) => FullOuter
+ case Seq(true, false) => LeftOuter
+ case Seq(false, true) => RightOuter
+ case Seq(false, false) => Inner
+ }.toBuffer
+
+ val joinedTables = tables.reduceLeft(Join(_,_, Inner, None))
+
+ // Must be transform down.
+ val joinedResult = joinedTables transform {
+ case j: Join =>
+ j.copy(
+ condition = Some(joinConditions.remove(joinConditions.length - 1)),
+ joinType = joinType.remove(joinType.length - 1))
+ }
+
+ val groups = (0 until joinExpressions.head.size).map(i => Coalesce(joinExpressions.map(_(i))))
+
+ // Unique join is not really the same as an outer join so we must group together results where
+ // the joinExpressions are the same. Taking the First of each value is only okay because the
+ // user of a unique join is implicitly promising that there is only one result.
+ // TODO: This doesn't actually work since [[Star]] is not a valid aggregate expression.
+ // Instead, we should figure out how important supporting this feature is and whether it is
+ // worth the number of hacks that will be required to implement it. Namely, we need to add
+ // some sort of mapped star expansion that would expand all child output rows to be similarly
+ // named output expressions where some aggregate expression has been applied (i.e. First).
+ ??? /// Aggregate(groups, Star(None, First(_)) :: Nil, joinedResult)
+
+ case Token(allJoinTokens(joinToken),
+ relation1 ::
+ relation2 :: other) =>
+ assert(other.size <= 1, s"Unhandled join child ${other}")
+ val joinType = joinToken match {
+ case "TOK_JOIN" => Inner
+ case "TOK_RIGHTOUTERJOIN" => RightOuter
+ case "TOK_LEFTOUTERJOIN" => LeftOuter
+ case "TOK_FULLOUTERJOIN" => FullOuter
+ }
+ assert(other.size <= 1, "Unhandled join clauses.")
+ Join(nodeToRelation(relation1),
+ nodeToRelation(relation2),
+ joinType,
+ other.headOption.map(nodeToExpr))
+
+ case a: ASTNode =>
+ throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ")
+ }
+
+ def nodeToSortOrder(node: Node): SortOrder = node match {
+ case Token("TOK_TABSORTCOLNAMEASC", sortExpr :: Nil) =>
+ SortOrder(nodeToExpr(sortExpr), Ascending)
+ case Token("TOK_TABSORTCOLNAMEDESC", sortExpr :: Nil) =>
+ SortOrder(nodeToExpr(sortExpr), Descending)
+
+ case a: ASTNode =>
+ throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ")
+ }
+
+ val destinationToken = "TOK_DESTINATION|TOK_INSERT_INTO".r
+ protected def nodeToDest(
+ node: Node,
+ query: LogicalPlan,
+ overwrite: Boolean): LogicalPlan = node match {
+ case Token(destinationToken(),
+ Token("TOK_DIR",
+ Token("TOK_TMP_FILE", Nil) :: Nil) :: Nil) =>
+ query
+
+ case Token(destinationToken(),
+ Token("TOK_TAB",
+ tableArgs) :: Nil) =>
+ val Some(tableNameParts) :: partitionClause :: Nil =
+ getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs)
+
+ val (db, tableName) =
+ tableNameParts.getChildren.map{ case Token(part, Nil) => cleanIdentifier(part)} match {
+ case Seq(tableOnly) => (None, tableOnly)
+ case Seq(databaseName, table) => (Some(databaseName), table)
+ }
+
+ val partitionKeys = partitionClause.map(_.getChildren.map {
+ // Parse partitions. We also make keys case insensitive.
+ case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) =>
+ cleanIdentifier(key.toLowerCase) -> Some(PlanUtils.stripQuotes(value))
+ case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) =>
+ cleanIdentifier(key.toLowerCase) -> None
+ }.toMap).getOrElse(Map.empty)
+
+ if (partitionKeys.values.exists(p => p.isEmpty)) {
+ throw new NotImplementedError(s"Do not support INSERT INTO/OVERWRITE with" +
+ s"dynamic partitioning.")
+ }
+
+ InsertIntoTable(UnresolvedRelation(db, tableName, None), partitionKeys, query, overwrite)
+
+ case a: ASTNode =>
+ throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ")
+ }
+
+ protected def selExprNodeToExpr(node: Node): Option[Expression] = node match {
+ case Token("TOK_SELEXPR",
+ e :: Nil) =>
+ Some(nodeToExpr(e))
+
+ case Token("TOK_SELEXPR",
+ e :: Token(alias, Nil) :: Nil) =>
+ Some(Alias(nodeToExpr(e), alias)())
+
+ /* Hints are ignored */
+ case Token("TOK_HINTLIST", _) => None
+
+ case a: ASTNode =>
+ throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ")
+ }
+
+
+ protected val escapedIdentifier = "`([^`]+)`".r
+ /** Strips backticks from ident if present */
+ protected def cleanIdentifier(ident: String): String = ident match {
+ case escapedIdentifier(i) => i
+ case plainIdent => plainIdent
+ }
+
+ val numericAstTypes = Seq(
+ HiveParser.Number,
+ HiveParser.TinyintLiteral,
+ HiveParser.SmallintLiteral,
+ HiveParser.BigintLiteral)
+
+ /* Case insensitive matches */
+ val COUNT = "(?i)COUNT".r
+ val AVG = "(?i)AVG".r
+ val SUM = "(?i)SUM".r
+ val RAND = "(?i)RAND".r
+ val AND = "(?i)AND".r
+ val OR = "(?i)OR".r
+ val NOT = "(?i)NOT".r
+ val TRUE = "(?i)TRUE".r
+ val FALSE = "(?i)FALSE".r
+
+ protected def nodeToExpr(node: Node): Expression = node match {
+ /* Attribute References */
+ case Token("TOK_TABLE_OR_COL",
+ Token(name, Nil) :: Nil) =>
+ UnresolvedAttribute(cleanIdentifier(name))
+ case Token(".", qualifier :: Token(attr, Nil) :: Nil) =>
+ nodeToExpr(qualifier) match {
+ case UnresolvedAttribute(qualifierName) =>
+ UnresolvedAttribute(qualifierName + "." + cleanIdentifier(attr))
+ // The precedence for . seems to be wrong, so [] binds tighter and we need to go inside to
+ // find the underlying attribute references.
+ case GetItem(UnresolvedAttribute(qualifierName), ordinal) =>
+ GetItem(UnresolvedAttribute(qualifierName + "." + cleanIdentifier(attr)), ordinal)
+ }
+
+ /* Stars (*) */
+ case Token("TOK_ALLCOLREF", Nil) => Star(None)
+ // The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only
+ // have a single child, which is the table name.
+ case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", Token(name, Nil) :: Nil) :: Nil) =>
+ Star(Some(name))
+
+ /* Aggregate Functions */
+ case Token("TOK_FUNCTION", Token(AVG(), Nil) :: arg :: Nil) => Average(nodeToExpr(arg))
+ case Token("TOK_FUNCTION", Token(COUNT(), Nil) :: arg :: Nil) => Count(nodeToExpr(arg))
+ case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => Count(Literal(1))
+ case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => CountDistinct(args.map(nodeToExpr))
+ case Token("TOK_FUNCTION", Token(SUM(), Nil) :: arg :: Nil) => Sum(nodeToExpr(arg))
+ case Token("TOK_FUNCTIONDI", Token(SUM(), Nil) :: arg :: Nil) => SumDistinct(nodeToExpr(arg))
+
+ /* Casts */
+ case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), StringType)
+ case Token("TOK_FUNCTION", Token("TOK_VARCHAR", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), StringType)
+ case Token("TOK_FUNCTION", Token("TOK_INT", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), IntegerType)
+ case Token("TOK_FUNCTION", Token("TOK_BIGINT", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), LongType)
+ case Token("TOK_FUNCTION", Token("TOK_FLOAT", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), FloatType)
+ case Token("TOK_FUNCTION", Token("TOK_DOUBLE", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), DoubleType)
+ case Token("TOK_FUNCTION", Token("TOK_SMALLINT", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), ShortType)
+ case Token("TOK_FUNCTION", Token("TOK_TINYINT", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), ByteType)
+ case Token("TOK_FUNCTION", Token("TOK_BINARY", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), BinaryType)
+ case Token("TOK_FUNCTION", Token("TOK_BOOLEAN", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), BooleanType)
+ case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), DecimalType)
+
+ /* Arithmetic */
+ case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child))
+ case Token("+", left :: right:: Nil) => Add(nodeToExpr(left), nodeToExpr(right))
+ case Token("-", left :: right:: Nil) => Subtract(nodeToExpr(left), nodeToExpr(right))
+ case Token("*", left :: right:: Nil) => Multiply(nodeToExpr(left), nodeToExpr(right))
+ case Token("/", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right))
+ case Token("DIV", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right))
+ case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right))
+
+ /* Comparisons */
+ case Token("=", left :: right:: Nil) => Equals(nodeToExpr(left), nodeToExpr(right))
+ case Token("!=", left :: right:: Nil) => Not(Equals(nodeToExpr(left), nodeToExpr(right)))
+ case Token("<>", left :: right:: Nil) => Not(Equals(nodeToExpr(left), nodeToExpr(right)))
+ case Token(">", left :: right:: Nil) => GreaterThan(nodeToExpr(left), nodeToExpr(right))
+ case Token(">=", left :: right:: Nil) => GreaterThanOrEqual(nodeToExpr(left), nodeToExpr(right))
+ case Token("<", left :: right:: Nil) => LessThan(nodeToExpr(left), nodeToExpr(right))
+ case Token("<=", left :: right:: Nil) => LessThanOrEqual(nodeToExpr(left), nodeToExpr(right))
+ case Token("LIKE", left :: right:: Nil) =>
+ UnresolvedFunction("LIKE", Seq(nodeToExpr(left), nodeToExpr(right)))
+ case Token("RLIKE", left :: right:: Nil) =>
+ UnresolvedFunction("RLIKE", Seq(nodeToExpr(left), nodeToExpr(right)))
+ case Token("REGEXP", left :: right:: Nil) =>
+ UnresolvedFunction("REGEXP", Seq(nodeToExpr(left), nodeToExpr(right)))
+ case Token("TOK_FUNCTION", Token("TOK_ISNOTNULL", Nil) :: child :: Nil) =>
+ IsNotNull(nodeToExpr(child))
+ case Token("TOK_FUNCTION", Token("TOK_ISNULL", Nil) :: child :: Nil) =>
+ IsNull(nodeToExpr(child))
+ case Token("TOK_FUNCTION", Token("IN", Nil) :: value :: list) =>
+ In(nodeToExpr(value), list.map(nodeToExpr))
+
+ /* Boolean Logic */
+ case Token(AND(), left :: right:: Nil) => And(nodeToExpr(left), nodeToExpr(right))
+ case Token(OR(), left :: right:: Nil) => Or(nodeToExpr(left), nodeToExpr(right))
+ case Token(NOT(), child :: Nil) => Not(nodeToExpr(child))
+
+ /* Complex datatype manipulation */
+ case Token("[", child :: ordinal :: Nil) =>
+ GetItem(nodeToExpr(child), nodeToExpr(ordinal))
+
+ /* Other functions */
+ case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand
+
+ /* UDFs - Must be last otherwise will preempt built in functions */
+ case Token("TOK_FUNCTION", Token(name, Nil) :: args) =>
+ UnresolvedFunction(name, args.map(nodeToExpr))
+ case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) =>
+ UnresolvedFunction(name, Star(None) :: Nil)
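+    // For illustration: a call such as substr(value, 1, 2) that matches none of the rules above
+    // is expected to fall through to here as Token("TOK_FUNCTION", Token("substr", Nil) :: args)
+    // and become UnresolvedFunction("substr", args.map(nodeToExpr)), to be resolved later against
+    // the Hive function registry.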
+
+ /* Literals */
+ case Token("TOK_NULL", Nil) => Literal(null, NullType)
+ case Token(TRUE(), Nil) => Literal(true, BooleanType)
+ case Token(FALSE(), Nil) => Literal(false, BooleanType)
+ case Token("TOK_STRINGLITERALSEQUENCE", strings) =>
+ Literal(strings.map(s => BaseSemanticAnalyzer.unescapeSQLString(s.getText)).mkString)
+
+ // This code is adapted from
+ // /ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java#L223
+ case ast: ASTNode if numericAstTypes contains ast.getType =>
+ var v: Literal = null
+ try {
+ if (ast.getText.endsWith("L")) {
+ // Literal bigint.
+ v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toLong, LongType)
+ } else if (ast.getText.endsWith("S")) {
+ // Literal smallint.
+ v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toShort, ShortType)
+ } else if (ast.getText.endsWith("Y")) {
+ // Literal tinyint.
+ v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toByte, ByteType)
+        } else if (ast.getText.endsWith("BD")) {
+          // Literal decimal.
+          val strVal = ast.getText.substring(0, ast.getText.length() - 2)
+          v = Literal(BigDecimal(strVal), DecimalType)
+        } else {
+          // Try Double, then Long, then Int; a NumberFormatException from a narrower parse is
+          // swallowed by the catch below, so the narrowest type that fits the text wins.
+          v = Literal(ast.getText.toDouble, DoubleType)
+          v = Literal(ast.getText.toLong, LongType)
+          v = Literal(ast.getText.toInt, IntegerType)
+        }
+ } catch {
+ case nfe: NumberFormatException => // Do nothing
+ }
+
+ if (v == null) {
+ sys.error(s"Failed to parse number ${ast.getText}")
+ } else {
+ v
+ }
+
+ case ast: ASTNode if ast.getType == HiveParser.StringLiteral =>
+ Literal(BaseSemanticAnalyzer.unescapeSQLString(ast.getText))
+
+ case a: ASTNode =>
+ throw new NotImplementedError(
+        s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText} :
+           |${dumpTree(a).toString}
+          """.stripMargin)
+ }
+
+
+ val explode = "(?i)explode".r
+ def nodesToGenerator(nodes: Seq[Node]): Generator = {
+ val function = nodes.head
+
+ val attributes = nodes.flatMap {
+ case Token(a, Nil) => a.toLowerCase :: Nil
+ case _ => Nil
+ }
+
+ function match {
+ case Token("TOK_FUNCTION", Token(explode(), Nil) :: child :: Nil) =>
+ Explode(attributes, nodeToExpr(child))
+
+ case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) =>
+ HiveGenericUdtf(functionName, attributes, children.map(nodeToExpr))
+
+ case a: ASTNode =>
+ throw new NotImplementedError(
+ s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText}, tree:
+ |${dumpTree(a).toString}
+ """.stripMargin)
+ }
+ }
+
+ def dumpTree(node: Node, builder: StringBuilder = new StringBuilder, indent: Int = 0)
+ : StringBuilder = {
+ node match {
+ case a: ASTNode => builder.append((" " * indent) + a.getText + "\n")
+ case other => sys.error(s"Non ASTNode encountered: $other")
+ }
+
+ Option(node.getChildren).map(_.toList).getOrElse(Nil).foreach(dumpTree(_, builder, indent + 1))
+ builder
+ }
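+  // For illustration, dumping a small expression AST should produce an indented listing along
+  // the lines of:
+  //   TOK_SELECT
+  //    TOK_SELEXPR
+  //     TOK_TABLE_OR_COL
+  //      key
+  // (one extra leading space per level, as built by the match above).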
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
new file mode 100644
index 0000000000..92d84208ab
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+
+import catalyst.expressions._
+import catalyst.planning._
+import catalyst.plans._
+import catalyst.plans.logical.{BaseRelation, LogicalPlan}
+
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.parquet.{ParquetRelation, InsertIntoParquetTable, ParquetTableScan}
+
+trait HiveStrategies {
+ // Possibly being too clever with types here... or not clever enough.
+ self: SQLContext#SparkPlanner =>
+
+ val hiveContext: HiveContext
+
+ object Scripts extends Strategy {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case logical.ScriptTransformation(input, script, output, child) =>
+ ScriptTransformation(input, script, output, planLater(child))(hiveContext) :: Nil
+ case _ => Nil
+ }
+ }
+
+ object DataSinks extends Strategy {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case logical.InsertIntoTable(table: MetastoreRelation, partition, child, overwrite) =>
+ InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil
+ case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) =>
+ InsertIntoParquetTable(table, planLater(child))(hiveContext.sparkContext) :: Nil
+ case _ => Nil
+ }
+ }
+
+ object HiveTableScans extends Strategy {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ // Push attributes into table scan when possible.
+ case p @ logical.Project(projectList, m: MetastoreRelation) if isSimpleProject(projectList) =>
+ HiveTableScan(projectList.asInstanceOf[Seq[Attribute]], m, None)(hiveContext) :: Nil
+ case m: MetastoreRelation =>
+ HiveTableScan(m.output, m, None)(hiveContext) :: Nil
+ case _ => Nil
+ }
+ }
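+  // For illustration: a plan such as Project(Seq(key), m) where `key` is a bare Attribute is
+  // expected to plan into HiveTableScan(Seq(key), m, None), pushing the column selection into
+  // the scan; projects over more complex expressions are left for other strategies to handle.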
+
+ /**
+ * A strategy used to detect filtering predicates on top of a partitioned relation to help
+ * partition pruning.
+ *
+   * This strategy itself doesn't perform partition pruning; it just collects and combines all the
+   * partition pruning predicates and passes them down to the underlying [[HiveTableScan]]
+   * operator, which does the actual pruning work.
+ */
+ object PartitionPrunings extends Strategy {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case p @ FilteredOperation(predicates, relation: MetastoreRelation)
+ if relation.isPartitioned =>
+
+ val partitionKeyIds = relation.partitionKeys.map(_.id).toSet
+
+ // Filter out all predicates that only deal with partition keys
+ val (pruningPredicates, otherPredicates) = predicates.partition {
+ _.references.map(_.id).subsetOf(partitionKeyIds)
+ }
+
+ val scan = HiveTableScan(
+ relation.output, relation, pruningPredicates.reduceLeftOption(And))(hiveContext)
+
+ otherPredicates
+ .reduceLeftOption(And)
+ .map(Filter(_, scan))
+ .getOrElse(scan) :: Nil
+
+ case _ =>
+ Nil
+ }
+ }
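+  // For illustration: given predicates like ds = '2008-04-08' (on a partition key) and
+  // key > 10 (on a data column), the former is expected to become the pruning predicate handed
+  // to HiveTableScan, while the latter remains in a Filter placed on top of the scan.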
+
+ /**
+ * A strategy that detects projects and filters over some relation and applies column pruning if
+ * possible. Partition pruning is applied first if the relation is partitioned.
+ */
+ object ColumnPrunings extends Strategy {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ // TODO(andre): the current mix of HiveRelation and ParquetRelation
+ // here appears artificial; try to refactor to break it into two
+ case PhysicalOperation(projectList, predicates, relation: BaseRelation) =>
+ val predicateOpt = predicates.reduceOption(And)
+ val predicateRefs = predicateOpt.map(_.references).getOrElse(Set.empty)
+ val projectRefs = projectList.flatMap(_.references)
+
+ // To figure out what columns to preserve after column pruning, we need to consider:
+ //
+ // 1. Columns referenced by the project list (order preserved)
+ // 2. Columns referenced by filtering predicates but not by project list
+ // 3. Relation output
+ //
+ // Then the final result is ((1 union 2) intersect 3)
+ val prunedCols = (projectRefs ++ (predicateRefs -- projectRefs)).intersect(relation.output)
+
+ val filteredScans =
+          if (relation.isPartitioned) { // from here on relation must be a [[MetastoreRelation]]
+ // Applies partition pruning first for partitioned table
+ val filteredRelation = predicateOpt.map(logical.Filter(_, relation)).getOrElse(relation)
+ PartitionPrunings(filteredRelation).view.map(_.transform {
+ case scan: HiveTableScan =>
+ scan.copy(attributes = prunedCols)(hiveContext)
+ })
+ } else {
+ val scan = relation match {
+ case MetastoreRelation(_, _, _) => {
+ HiveTableScan(
+ prunedCols,
+ relation.asInstanceOf[MetastoreRelation],
+ None)(hiveContext)
+ }
+ case ParquetRelation(_, _) => {
+ ParquetTableScan(
+ relation.output,
+ relation.asInstanceOf[ParquetRelation],
+ None)(hiveContext.sparkContext)
+ .pruneColumns(prunedCols)
+ }
+ }
+ predicateOpt.map(execution.Filter(_, scan)).getOrElse(scan) :: Nil
+ }
+
+ if (isSimpleProject(projectList) && prunedCols == projectRefs) {
+ filteredScans
+ } else {
+ filteredScans.view.map(execution.Project(projectList, _))
+ }
+
+ case _ =>
+ Nil
+ }
+ }
+
+ /**
+ * Returns true if `projectList` only performs column pruning and does not evaluate other
+ * complex expressions.
+ */
+ def isSimpleProject(projectList: Seq[NamedExpression]) = {
+ projectList.forall(_.isInstanceOf[Attribute])
+ }
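+  // For illustration: a project list such as Seq(key, value) consisting of bare attribute
+  // references is "simple", while one containing an expression like key + 1 is not, since an
+  // Add is not an Attribute.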
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ScriptTransformation.scala
new file mode 100644
index 0000000000..f20e9d4de4
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ScriptTransformation.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+
+import java.io.{InputStreamReader, BufferedReader}
+
+import catalyst.expressions._
+import org.apache.spark.sql.execution._
+
+import scala.collection.JavaConversions._
+
+/**
+ * Transforms the input by forking and running the specified script.
+ *
+ * @param input the set of expressions that should be passed to the script.
+ * @param script the command that should be executed.
+ * @param output the attributes that are produced by the script.
+ */
+case class ScriptTransformation(
+ input: Seq[Expression],
+ script: String,
+ output: Seq[Attribute],
+ child: SparkPlan)(@transient sc: HiveContext)
+ extends UnaryNode {
+
+ override def otherCopyArgs = sc :: Nil
+
+ def execute() = {
+ child.execute().mapPartitions { iter =>
+ val cmd = List("/bin/bash", "-c", script)
+ val builder = new ProcessBuilder(cmd)
+ val proc = builder.start()
+ val inputStream = proc.getInputStream
+ val outputStream = proc.getOutputStream
+ val reader = new BufferedReader(new InputStreamReader(inputStream))
+
+ // TODO: This should be exposed as an iterator instead of reading in all the data at once.
+ val outputLines = collection.mutable.ArrayBuffer[Row]()
+ val readerThread = new Thread("Transform OutputReader") {
+ override def run() {
+ var curLine = reader.readLine()
+ while (curLine != null) {
+ // TODO: Use SerDe
+ outputLines += new GenericRow(curLine.split("\t").asInstanceOf[Array[Any]])
+ curLine = reader.readLine()
+ }
+ }
+ }
+ readerThread.start()
+ val outputProjection = new Projection(input)
+ iter
+ .map(outputProjection)
+ // TODO: Use SerDe
+ .map(_.mkString("", "\t", "\n").getBytes).foreach(outputStream.write)
+ outputStream.close()
+ readerThread.join()
+ outputLines.toIterator
+ }
+ }
+}
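+// For illustration: a query along the lines of
+//   SELECT TRANSFORM (key, value) USING 'cat' AS (tKey, tValue) FROM src
+// is expected to be planned into this operator, with key/value as the input expressions, "cat"
+// as the script, and tKey/tValue as the output attributes (rows are split on tabs above).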
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
new file mode 100644
index 0000000000..71d751cbc4
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants._
+import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable}
+import org.apache.hadoop.hive.ql.plan.TableDesc
+import org.apache.hadoop.hive.serde2.Deserializer
+import org.apache.hadoop.hive.ql.exec.Utilities
+import org.apache.hadoop.io.Writable
+import org.apache.hadoop.fs.{Path, PathFilter}
+import org.apache.hadoop.mapred.{FileInputFormat, JobConf, InputFormat}
+
+import org.apache.spark.SerializableWritable
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.{HadoopRDD, UnionRDD, EmptyRDD, RDD}
+
+
+/**
+ * A trait for subclasses that handle table scans.
+ */
+private[hive] sealed trait TableReader {
+ def makeRDDForTable(hiveTable: HiveTable): RDD[_]
+
+ def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[_]
+
+}
+
+
+/**
+ * Helper class for scanning tables stored in Hadoop - e.g., to read Hive tables that reside in the
+ * data warehouse directory.
+ */
+private[hive]
+class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveContext)
+ extends TableReader {
+
+ // Choose the minimum number of splits. If mapred.map.tasks is set, then use that unless
+ // it is smaller than what Spark suggests.
+ private val _minSplitsPerRDD = math.max(
+ sc.hiveconf.getInt("mapred.map.tasks", 1), sc.sparkContext.defaultMinSplits)
+
+
+ // TODO: set aws s3 credentials.
+
+ private val _broadcastedHiveConf =
+ sc.sparkContext.broadcast(new SerializableWritable(sc.hiveconf))
+
+ def broadcastedHiveConf = _broadcastedHiveConf
+
+ def hiveConf = _broadcastedHiveConf.value.value
+
+ override def makeRDDForTable(hiveTable: HiveTable): RDD[_] =
+ makeRDDForTable(
+ hiveTable,
+ _tableDesc.getDeserializerClass.asInstanceOf[Class[Deserializer]],
+ filterOpt = None)
+
+ /**
+ * Creates a Hadoop RDD to read data from the target table's data directory. Returns a transformed
+ * RDD that contains deserialized rows.
+ *
+ * @param hiveTable Hive metadata for the table being scanned.
+ * @param deserializerClass Class of the SerDe used to deserialize Writables read from Hadoop.
+ * @param filterOpt If defined, then the filter is used to reject files contained in the data
+ * directory being read. If None, then all files are accepted.
+ */
+ def makeRDDForTable(
+ hiveTable: HiveTable,
+ deserializerClass: Class[_ <: Deserializer],
+ filterOpt: Option[PathFilter]): RDD[_] =
+ {
+    assert(!hiveTable.isPartitioned, """makeRDDForTable() cannot be called on a partitioned table,
+      since input formats may differ across partitions. Use makeRDDForPartitionedTable() instead.""")
+
+ // Create local references to member variables, so that the entire `this` object won't be
+ // serialized in the closure below.
+ val tableDesc = _tableDesc
+ val broadcastedHiveConf = _broadcastedHiveConf
+
+ val tablePath = hiveTable.getPath
+ val inputPathStr = applyFilterIfNeeded(tablePath, filterOpt)
+
+ //logDebug("Table input: %s".format(tablePath))
+ val ifc = hiveTable.getInputFormatClass
+ .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]]
+ val hadoopRDD = createHadoopRdd(tableDesc, inputPathStr, ifc)
+
+ val deserializedHadoopRDD = hadoopRDD.mapPartitions { iter =>
+ val hconf = broadcastedHiveConf.value.value
+ val deserializer = deserializerClass.newInstance()
+ deserializer.initialize(hconf, tableDesc.getProperties)
+
+ // Deserialize each Writable to get the row value.
+ iter.map {
+ case v: Writable => deserializer.deserialize(v)
+ case value =>
+ sys.error(s"Unable to deserialize non-Writable: $value of ${value.getClass.getName}")
+ }
+ }
+ deserializedHadoopRDD
+ }
+
+ override def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[_] = {
+ val partitionToDeserializer = partitions.map(part =>
+ (part, part.getDeserializer.getClass.asInstanceOf[Class[Deserializer]])).toMap
+ makeRDDForPartitionedTable(partitionToDeserializer, filterOpt = None)
+ }
+
+ /**
+ * Create a HadoopRDD for every partition key specified in the query. Note that for on-disk Hive
+ * tables, a data directory is created for each partition corresponding to keys specified using
+   * 'PARTITIONED BY'.
+ *
+ * @param partitionToDeserializer Mapping from a Hive Partition metadata object to the SerDe
+ * class to use to deserialize input Writables from the corresponding partition.
+ * @param filterOpt If defined, then the filter is used to reject files contained in the data
+ * subdirectory of each partition being read. If None, then all files are accepted.
+ */
+ def makeRDDForPartitionedTable(
+ partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]],
+ filterOpt: Option[PathFilter]): RDD[_] =
+ {
+ val hivePartitionRDDs = partitionToDeserializer.map { case (partition, partDeserializer) =>
+ val partDesc = Utilities.getPartitionDesc(partition)
+ val partPath = partition.getPartitionPath
+ val inputPathStr = applyFilterIfNeeded(partPath, filterOpt)
+ val ifc = partDesc.getInputFileFormatClass
+ .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]]
+ // Get partition field info
+ val partSpec = partDesc.getPartSpec
+ val partProps = partDesc.getProperties
+
+ val partColsDelimited: String = partProps.getProperty(META_TABLE_PARTITION_COLUMNS)
+ // Partitioning columns are delimited by "/"
+ val partCols = partColsDelimited.trim().split("/").toSeq
+ // 'partValues[i]' contains the value for the partitioning column at 'partCols[i]'.
+ val partValues = if (partSpec == null) {
+ Array.fill(partCols.size)(new String)
+ } else {
+ partCols.map(col => new String(partSpec.get(col))).toArray
+ }
+
+ // Create local references so that the outer object isn't serialized.
+ val tableDesc = _tableDesc
+ val broadcastedHiveConf = _broadcastedHiveConf
+ val localDeserializer = partDeserializer
+
+ val hivePartitionRDD = createHadoopRdd(tableDesc, inputPathStr, ifc)
+ hivePartitionRDD.mapPartitions { iter =>
+ val hconf = broadcastedHiveConf.value.value
+ val rowWithPartArr = new Array[Object](2)
+ // Map each tuple to a row object
+ iter.map { value =>
+ val deserializer = localDeserializer.newInstance()
+ deserializer.initialize(hconf, partProps)
+ val deserializedRow = deserializer.deserialize(value)
+ rowWithPartArr.update(0, deserializedRow)
+ rowWithPartArr.update(1, partValues)
+ rowWithPartArr.asInstanceOf[Object]
+ }
+ }
+ }.toSeq
+ // Even if we don't use any partitions, we still need an empty RDD
+ if (hivePartitionRDDs.size == 0) {
+ new EmptyRDD[Object](sc.sparkContext)
+ } else {
+ new UnionRDD(hivePartitionRDDs(0).context, hivePartitionRDDs)
+ }
+ }
+
+ /**
+ * If `filterOpt` is defined, then it will be used to filter files from `path`. These files are
+ * returned in a single, comma-separated string.
+ */
+ private def applyFilterIfNeeded(path: Path, filterOpt: Option[PathFilter]): String = {
+ filterOpt match {
+ case Some(filter) =>
+ val fs = path.getFileSystem(sc.hiveconf)
+ val filteredFiles = fs.listStatus(path, filter).map(_.getPath.toString)
+ filteredFiles.mkString(",")
+ case None => path.toString
+ }
+ }
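+  // For illustration: with a hypothetical PathFilter accepting only files ending in ".txt",
+  // this might return "hdfs://host/warehouse/t/a.txt,hdfs://host/warehouse/t/b.txt"; with no
+  // filter, the table (or partition) directory itself is returned as a single path string.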
+
+ /**
+ * Creates a HadoopRDD based on the broadcasted HiveConf and other job properties that will be
+ * applied locally on each slave.
+ */
+ private def createHadoopRdd(
+ tableDesc: TableDesc,
+ path: String,
+ inputFormatClass: Class[InputFormat[Writable, Writable]])
+ : RDD[Writable] = {
+ val initializeJobConfFunc = HadoopTableReader.initializeLocalJobConfFunc(path, tableDesc) _
+
+ val rdd = new HadoopRDD(
+ sc.sparkContext,
+ _broadcastedHiveConf.asInstanceOf[Broadcast[SerializableWritable[Configuration]]],
+ Some(initializeJobConfFunc),
+ inputFormatClass,
+ classOf[Writable],
+ classOf[Writable],
+ _minSplitsPerRDD)
+
+ // Only take the value (skip the key) because Hive works only with values.
+ rdd.map(_._2)
+ }
+
+}
+
+private[hive] object HadoopTableReader {
+
+ /**
+   * Curried. After being given an argument for 'path', the resulting JobConf => Unit closure is
+   * used to instantiate a HadoopRDD.
+ */
+ def initializeLocalJobConfFunc(path: String, tableDesc: TableDesc)(jobConf: JobConf) {
+ FileInputFormat.setInputPaths(jobConf, path)
+ if (tableDesc != null) {
+ Utilities.copyTableJobPropertiesToConf(tableDesc, jobConf)
+ }
+ val bufferSize = System.getProperty("spark.buffer.size", "65536")
+ jobConf.set("io.file.buffer.size", bufferSize)
+ }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
new file mode 100644
index 0000000000..17ae4ef63c
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
@@ -0,0 +1,341 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+
+import java.io.File
+import java.util.{Set => JavaSet}
+
+import scala.collection.mutable
+import scala.collection.JavaConversions._
+import scala.language.implicitConversions
+
+import org.apache.hadoop.hive.metastore.api.{SerDeInfo, StorageDescriptor}
+import org.apache.hadoop.hive.metastore.MetaStoreUtils
+import org.apache.hadoop.hive.ql.exec.FunctionRegistry
+import org.apache.hadoop.hive.ql.io.avro.{AvroContainerOutputFormat, AvroContainerInputFormat}
+import org.apache.hadoop.hive.ql.metadata.Table
+import org.apache.hadoop.hive.serde2.avro.AvroSerDe
+import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
+import org.apache.hadoop.hive.serde2.RegexSerDe
+
+import org.apache.spark.{SparkContext, SparkConf}
+
+import catalyst.analysis._
+import catalyst.plans.logical.{LogicalPlan, NativeCommand}
+import catalyst.util._
+
+object TestHive
+ extends TestHiveContext(new SparkContext("local", "TestSQLContext", new SparkConf()))
+
+/**
+ * A locally running test instance of Spark's Hive execution engine.
+ *
+ * Data from [[testTables]] will be automatically loaded whenever a query is run over those tables.
+ * Calling [[reset]] will delete all tables and other state in the database, leaving the database
+ * in a "clean" state.
+ *
+ * TestHive is a singleton object version of this class because instantiating multiple copies of
+ * the Hive metastore seems to lead to weird non-deterministic failures. Therefore, the execution
+ * of test cases that rely on TestHive must be serialized.
+ */
+class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) {
+ self =>
+
+ // By clearing the port we force Spark to pick a new one. This allows us to rerun tests
+ // without restarting the JVM.
+ System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
+
+ override lazy val warehousePath = getTempFilePath("sparkHiveWarehouse").getCanonicalPath
+ override lazy val metastorePath = getTempFilePath("sparkHiveMetastore").getCanonicalPath
+
+ /** The location of the compiled hive distribution */
+ lazy val hiveHome = envVarToFile("HIVE_HOME")
+ /** The location of the hive source code. */
+ lazy val hiveDevHome = envVarToFile("HIVE_DEV_HOME")
+
+ // Override so we can intercept relative paths and rewrite them to point at hive.
+ override def runSqlHive(sql: String): Seq[String] = super.runSqlHive(rewritePaths(sql))
+
+ override def executePlan(plan: LogicalPlan): this.QueryExecution =
+ new this.QueryExecution { val logical = plan }
+
+ /**
+   * Returns the value of the specified environment variable as a [[java.io.File]], or None if
+   * the variable is not set.
+ */
+ private def envVarToFile(envVar: String): Option[File] = {
+ Option(System.getenv(envVar)).map(new File(_))
+ }
+
+ /**
+ * Replaces relative paths to the parent directory "../" with hiveDevHome since this is how the
+ * hive test cases assume the system is set up.
+ */
+ private def rewritePaths(cmd: String): String =
+ if (cmd.toUpperCase contains "LOAD DATA") {
+ val testDataLocation =
+ hiveDevHome.map(_.getCanonicalPath).getOrElse(inRepoTests.getCanonicalPath)
+ cmd.replaceAll("\\.\\.", testDataLocation)
+ } else {
+ cmd
+ }
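+  // For illustration: a command such as
+  //   LOAD DATA LOCAL INPATH '../data/files/kv1.txt' INTO TABLE src
+  // should have its "../" rewritten to point at the Hive dev home (or the in-repo test
+  // resources when HIVE_DEV_HOME is not set), so bundled test data files resolve correctly.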
+
+ val hiveFilesTemp = File.createTempFile("catalystHiveFiles", "")
+ hiveFilesTemp.delete()
+ hiveFilesTemp.mkdir()
+
+ val inRepoTests = new File("src/test/resources/")
+ def getHiveFile(path: String): File = {
+ val stripped = path.replaceAll("""\.\.\/""", "")
+ hiveDevHome
+ .map(new File(_, stripped))
+ .filter(_.exists)
+ .getOrElse(new File(inRepoTests, stripped))
+ }
+
+ val describedTable = "DESCRIBE (\\w+)".r
+
+ class SqlQueryExecution(sql: String) extends this.QueryExecution {
+ lazy val logical = HiveQl.parseSql(sql)
+ def hiveExec() = runSqlHive(sql)
+ override def toString = sql + "\n" + super.toString
+ }
+
+ /**
+ * Override QueryExecution with special debug workflow.
+ */
+ abstract class QueryExecution extends super.QueryExecution {
+ override lazy val analyzed = {
+ val describedTables = logical match {
+ case NativeCommand(describedTable(tbl)) => tbl :: Nil
+ case _ => Nil
+ }
+
+ // Make sure any test tables referenced are loaded.
+ val referencedTables =
+ describedTables ++
+ logical.collect { case UnresolvedRelation(databaseName, name, _) => name }
+ val referencedTestTables = referencedTables.filter(testTables.contains)
+ logger.debug(s"Query references test tables: ${referencedTestTables.mkString(", ")}")
+ referencedTestTables.foreach(loadTestTable)
+ // Proceed with analysis.
+ analyzer(logical)
+ }
+ }
+
+ case class TestTable(name: String, commands: (()=>Unit)*)
+
+ implicit class SqlCmd(sql: String) {
+ def cmd = () => new SqlQueryExecution(sql).stringResult(): Unit
+ }
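+  // For illustration: "CREATE TABLE foo (key INT)".cmd (a hypothetical statement) yields a
+  // () => Unit that runs the statement through SqlQueryExecution when invoked; the test table
+  // definitions below rely on this implicit.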
+
+ /**
+   * A list of test tables and the DDL required to initialize them. A test table is loaded on
+   * demand when a query is run against it.
+ */
+ lazy val testTables = new mutable.HashMap[String, TestTable]()
+ def registerTestTable(testTable: TestTable) = testTables += (testTable.name -> testTable)
+
+ // The test tables that are defined in the Hive QTestUtil.
+ // /itests/util/src/main/java/org/apache/hadoop/hive/ql/QTestUtil.java
+ val hiveQTestUtilTables = Seq(
+ TestTable("src",
+ "CREATE TABLE src (key INT, value STRING)".cmd,
+ s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' INTO TABLE src".cmd),
+ TestTable("src1",
+ "CREATE TABLE src1 (key INT, value STRING)".cmd,
+ s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd),
+ TestTable("dest1",
+ "CREATE TABLE IF NOT EXISTS dest1 (key INT, value STRING)".cmd),
+ TestTable("dest2",
+ "CREATE TABLE IF NOT EXISTS dest2 (key INT, value STRING)".cmd),
+ TestTable("dest3",
+ "CREATE TABLE IF NOT EXISTS dest3 (key INT, value STRING)".cmd),
+ TestTable("srcpart", () => {
+ runSqlHive(
+ "CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)")
+ for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) {
+ runSqlHive(
+ s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}'
+ |OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr')
+ """.stripMargin)
+ }
+ }),
+ TestTable("srcpart1", () => {
+ runSqlHive("CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)")
+ for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- 11 to 12) {
+ runSqlHive(
+ s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}'
+ |OVERWRITE INTO TABLE srcpart1 PARTITION (ds='$ds',hr='$hr')
+ """.stripMargin)
+ }
+ }),
+ TestTable("src_thrift", () => {
+ import org.apache.thrift.protocol.TBinaryProtocol
+ import org.apache.hadoop.hive.serde2.thrift.test.Complex
+ import org.apache.hadoop.hive.serde2.thrift.ThriftDeserializer
+ import org.apache.hadoop.mapred.SequenceFileInputFormat
+ import org.apache.hadoop.mapred.SequenceFileOutputFormat
+
+ val srcThrift = new Table("default", "src_thrift")
+ srcThrift.setFields(Nil)
+ srcThrift.setInputFormatClass(classOf[SequenceFileInputFormat[_,_]].getName)
+ // In Hive, SequenceFileOutputFormat will be substituted by HiveSequenceFileOutputFormat.
+ srcThrift.setOutputFormatClass(classOf[SequenceFileOutputFormat[_,_]].getName)
+ srcThrift.setSerializationLib(classOf[ThriftDeserializer].getName)
+ srcThrift.setSerdeParam("serialization.class", classOf[Complex].getName)
+ srcThrift.setSerdeParam("serialization.format", classOf[TBinaryProtocol].getName)
+ catalog.client.createTable(srcThrift)
+
+
+ runSqlHive(
+ s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/complex.seq")}' INTO TABLE src_thrift")
+ }),
+ TestTable("serdeins",
+ s"""CREATE TABLE serdeins (key INT, value STRING)
+ |ROW FORMAT SERDE '${classOf[LazySimpleSerDe].getCanonicalName}'
+ |WITH SERDEPROPERTIES ('field.delim'='\\t')
+ """.stripMargin.cmd,
+ "INSERT OVERWRITE TABLE serdeins SELECT * FROM src".cmd),
+ TestTable("sales",
+ s"""CREATE TABLE IF NOT EXISTS sales (key STRING, value INT)
+ |ROW FORMAT SERDE '${classOf[RegexSerDe].getCanonicalName}'
+ |WITH SERDEPROPERTIES ("input.regex" = "([^ ]*)\t([^ ]*)")
+ """.stripMargin.cmd,
+ s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/sales.txt")}' INTO TABLE sales".cmd),
+ TestTable("episodes",
+ s"""CREATE TABLE episodes (title STRING, air_date STRING, doctor INT)
+ |ROW FORMAT SERDE '${classOf[AvroSerDe].getCanonicalName}'
+ |STORED AS
+ |INPUTFORMAT '${classOf[AvroContainerInputFormat].getCanonicalName}'
+ |OUTPUTFORMAT '${classOf[AvroContainerOutputFormat].getCanonicalName}'
+ |TBLPROPERTIES (
+ | 'avro.schema.literal'='{
+ | "type": "record",
+ | "name": "episodes",
+ | "namespace": "testing.hive.avro.serde",
+ | "fields": [
+ | {
+ | "name": "title",
+ | "type": "string",
+ | "doc": "episode title"
+ | },
+ | {
+ | "name": "air_date",
+ | "type": "string",
+ | "doc": "initial date"
+ | },
+ | {
+ | "name": "doctor",
+ | "type": "int",
+ | "doc": "main actor playing the Doctor in episode"
+ | }
+ | ]
+ | }'
+ |)
+ """.stripMargin.cmd,
+ s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/episodes.avro")}' INTO TABLE episodes".cmd
+ )
+ )
+
+ hiveQTestUtilTables.foreach(registerTestTable)
+
+ private val loadedTables = new collection.mutable.HashSet[String]
+
+ def loadTestTable(name: String) {
+ if (!(loadedTables contains name)) {
+      // Marks the table as loaded first to prevent infinite mutually recursive table loading.
+ loadedTables += name
+ logger.info(s"Loading test table $name")
+ val createCmds =
+ testTables.get(name).map(_.commands).getOrElse(sys.error(s"Unknown test table $name"))
+ createCmds.foreach(_())
+ }
+ }
+
+ /**
+ * Records the UDFs present when the server starts, so we can delete ones that are created by
+ * tests.
+ */
+ protected val originalUdfs: JavaSet[String] = FunctionRegistry.getFunctionNames
+
+ /**
+ * Resets the test instance by deleting any tables that have been created.
+ * TODO: also clear out UDFs, views, etc.
+ */
+ def reset() {
+ try {
+ // HACK: Hive is too noisy by default.
+ org.apache.log4j.LogManager.getCurrentLoggers.foreach { logger =>
+ logger.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN)
+ }
+
+      // It is important that we RESET first, as broken hooks that might have been set could break
+      // other SQL statements executed here.
+ runSqlHive("RESET")
+ // For some reason, RESET does not reset the following variables...
+ runSqlHive("set datanucleus.cache.collections=true")
+ runSqlHive("set datanucleus.cache.collections.lazy=true")
+ // Lots of tests fail if we do not change the partition whitelist from the default.
+ runSqlHive("set hive.metastore.partition.name.whitelist.pattern=.*")
+
+ loadedTables.clear()
+ catalog.client.getAllTables("default").foreach { t =>
+ logger.debug(s"Deleting table $t")
+ val table = catalog.client.getTable("default", t)
+
+ catalog.client.getIndexes("default", t, 255).foreach { index =>
+ catalog.client.dropIndex("default", t, index.getIndexName, true)
+ }
+
+ if (!table.isIndexTable) {
+ catalog.client.dropTable("default", t)
+ }
+ }
+
+ catalog.client.getAllDatabases.filterNot(_ == "default").foreach { db =>
+ logger.debug(s"Dropping Database: $db")
+ catalog.client.dropDatabase(db, true, false, true)
+ }
+
+ FunctionRegistry.getFunctionNames.filterNot(originalUdfs.contains(_)).foreach { udfName =>
+ FunctionRegistry.unregisterTemporaryUDF(udfName)
+ }
+
+ configure()
+
+ runSqlHive("USE default")
+
+ // Just loading src makes a lot of tests pass. This is because some tests do something like
+ // drop an index on src at the beginning. Since we just pass DDL to hive this bypasses our
+ // Analyzer and thus the test table auto-loading mechanism.
+ // Remove after we handle more DDL operations natively.
+ loadTestTable("src")
+ loadTestTable("srcpart")
+ } catch {
+ case e: Exception =>
+ logger.error(s"FATAL ERROR: Failed to reset TestDB state. $e")
+ // At this point there is really no reason to continue, but the test framework traps exits.
+ // So instead we just pause forever so that at least the developer can see where things
+ // started to go wrong.
+ Thread.sleep(100000)
+ }
+ }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala
new file mode 100644
index 0000000000..d20fd87f34
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala
@@ -0,0 +1,356 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+
+import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar}
+import org.apache.hadoop.hive.metastore.MetaStoreUtils
+import org.apache.hadoop.hive.ql.Context
+import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Hive}
+import org.apache.hadoop.hive.ql.plan.{TableDesc, FileSinkDesc}
+import org.apache.hadoop.hive.serde2.Serializer
+import org.apache.hadoop.hive.serde2.objectinspector._
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveDecimalObjectInspector
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveVarcharObjectInspector
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.Writable
+import org.apache.hadoop.mapred._
+
+import catalyst.expressions._
+import catalyst.types.{BooleanType, DataType}
+import org.apache.spark.{TaskContext, SparkException}
+import catalyst.expressions.Cast
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.execution._
+
+import scala.Some
+import scala.collection.immutable.ListMap
+
+/* Implicits */
+import scala.collection.JavaConversions._
+
+/**
+ * The Hive table scan operator. Column and partition pruning are both handled.
+ *
+ * @constructor
+ * @param attributes Attributes to be fetched from the Hive table.
+ * @param relation The Hive table to be scanned.
+ * @param partitionPruningPred An optional partition pruning predicate for partitioned table.
+ */
+case class HiveTableScan(
+ attributes: Seq[Attribute],
+ relation: MetastoreRelation,
+ partitionPruningPred: Option[Expression])(
+ @transient val sc: HiveContext)
+ extends LeafNode
+ with HiveInspectors {
+
+ require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned,
+ "Partition pruning predicates only supported for partitioned tables.")
+
+ // Bind all partition key attribute references in the partition pruning predicate for later
+ // evaluation.
+ private val boundPruningPred = partitionPruningPred.map { pred =>
+ require(
+ pred.dataType == BooleanType,
+ s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.")
+
+ BindReferences.bindReference(pred, relation.partitionKeys)
+ }
+
+ @transient
+ val hadoopReader = new HadoopTableReader(relation.tableDesc, sc)
+
+ /**
+ * The hive object inspector for this table, which can be used to extract values from the
+ * serialized row representation.
+ */
+ @transient
+ lazy val objectInspector =
+ relation.tableDesc.getDeserializer.getObjectInspector.asInstanceOf[StructObjectInspector]
+
+ /**
+   * Functions that extract the requested attributes from the hive output. Partition key values
+   * are cast from strings to their declared data types.
+ */
+ @transient
+ protected lazy val attributeFunctions: Seq[(Any, Array[String]) => Any] = {
+ attributes.map { a =>
+ val ordinal = relation.partitionKeys.indexOf(a)
+ if (ordinal >= 0) {
+ (_: Any, partitionKeys: Array[String]) => {
+ val value = partitionKeys(ordinal)
+ val dataType = relation.partitionKeys(ordinal).dataType
+ castFromString(value, dataType)
+ }
+ } else {
+ val ref = objectInspector.getAllStructFieldRefs
+ .find(_.getFieldName == a.name)
+ .getOrElse(sys.error(s"Can't find attribute $a"))
+ (row: Any, _: Array[String]) => {
+ val data = objectInspector.getStructFieldData(row, ref)
+ unwrapData(data, ref.getFieldObjectInspector)
+ }
+ }
+ }
+ }
+
+ private def castFromString(value: String, dataType: DataType) = {
+ Cast(Literal(value), dataType).apply(null)
+ }
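+  // For illustration: castFromString("11", IntegerType) is expected to evaluate to 11: Int, so
+  // that partition key strings taken from the directory layout carry their declared types.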
+
+ @transient
+ def inputRdd = if (!relation.hiveQlTable.isPartitioned) {
+ hadoopReader.makeRDDForTable(relation.hiveQlTable)
+ } else {
+ hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions))
+ }
+
+ /**
+   * Prunes partitions not involved in the query plan.
+ *
+ * @param partitions All partitions of the relation.
+ * @return Partitions that are involved in the query plan.
+ */
+ private[hive] def prunePartitions(partitions: Seq[HivePartition]) = {
+ boundPruningPred match {
+ case None => partitions
+ case Some(shouldKeep) => partitions.filter { part =>
+ val dataTypes = relation.partitionKeys.map(_.dataType)
+ val castedValues = for ((value, dataType) <- part.getValues.zip(dataTypes)) yield {
+ castFromString(value, dataType)
+ }
+
+        // Only partition values are needed here, since the predicate has already been bound to
+ // partition key attribute references.
+ val row = new GenericRow(castedValues.toArray)
+ shouldKeep.apply(row).asInstanceOf[Boolean]
+ }
+ }
+ }
+
+ def execute() = {
+ inputRdd.map { row =>
+ val values = row match {
+ case Array(deserializedRow: AnyRef, partitionKeys: Array[String]) =>
+ attributeFunctions.map(_(deserializedRow, partitionKeys))
+ case deserializedRow: AnyRef =>
+ attributeFunctions.map(_(deserializedRow, Array.empty))
+ }
+ buildRow(values.map {
+ case n: String if n.toLowerCase == "null" => null
+ case varchar: org.apache.hadoop.hive.common.`type`.HiveVarchar => varchar.getValue
+ case decimal: org.apache.hadoop.hive.common.`type`.HiveDecimal =>
+ BigDecimal(decimal.bigDecimalValue)
+ case other => other
+ })
+ }
+ }
+
+ def output = attributes
+}
+
+case class InsertIntoHiveTable(
+ table: MetastoreRelation,
+ partition: Map[String, Option[String]],
+ child: SparkPlan,
+ overwrite: Boolean)
+ (@transient sc: HiveContext)
+ extends UnaryNode {
+
+ val outputClass = newSerializer(table.tableDesc).getSerializedClass
+ @transient private val hiveContext = new Context(sc.hiveconf)
+ @transient private val db = Hive.get(sc.hiveconf)
+
+ private def newSerializer(tableDesc: TableDesc): Serializer = {
+ val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer]
+ serializer.initialize(null, tableDesc.getProperties)
+ serializer
+ }
+
+ override def otherCopyArgs = sc :: Nil
+
+ def output = child.output
+
+ /**
+   * Wraps values in Hive types based on the object inspector.
+ * TODO: Consolidate all hive OI/data interface code.
+ */
+ protected def wrap(a: (Any, ObjectInspector)): Any = a match {
+ case (s: String, oi: JavaHiveVarcharObjectInspector) => new HiveVarchar(s, s.size)
+ case (bd: BigDecimal, oi: JavaHiveDecimalObjectInspector) =>
+ new HiveDecimal(bd.underlying())
+ case (row: Row, oi: StandardStructObjectInspector) =>
+ val struct = oi.create()
+ row.zip(oi.getAllStructFieldRefs).foreach {
+ case (data, field) =>
+ oi.setStructFieldData(struct, field, wrap(data, field.getFieldObjectInspector))
+ }
+ struct
+ case (s: Seq[_], oi: ListObjectInspector) =>
+ val wrappedSeq = s.map(wrap(_, oi.getListElementObjectInspector))
+ seqAsJavaList(wrappedSeq)
+ case (obj, _) => obj
+ }
+
+ def saveAsHiveFile(
+ rdd: RDD[Writable],
+ valueClass: Class[_],
+ fileSinkConf: FileSinkDesc,
+ conf: JobConf,
+ isCompressed: Boolean) {
+ if (valueClass == null) {
+ throw new SparkException("Output value class not set")
+ }
+ conf.setOutputValueClass(valueClass)
+ if (fileSinkConf.getTableInfo.getOutputFileFormatClassName == null) {
+ throw new SparkException("Output format class not set")
+ }
+ // Doesn't work in Scala 2.9 due to what may be a generics bug
+ // TODO: Should we uncomment this for Scala 2.10?
+ // conf.setOutputFormat(outputFormatClass)
+ conf.set("mapred.output.format.class", fileSinkConf.getTableInfo.getOutputFileFormatClassName)
+ if (isCompressed) {
+ // Please note that isCompressed, "mapred.output.compress", "mapred.output.compression.codec",
+ // and "mapred.output.compression.type" have no impact on ORC because it uses table properties
+ // to store compression information.
+ conf.set("mapred.output.compress", "true")
+ fileSinkConf.setCompressed(true)
+ fileSinkConf.setCompressCodec(conf.get("mapred.output.compression.codec"))
+ fileSinkConf.setCompressType(conf.get("mapred.output.compression.type"))
+ }
+ conf.setOutputCommitter(classOf[FileOutputCommitter])
+ FileOutputFormat.setOutputPath(
+ conf,
+ SparkHiveHadoopWriter.createPathFromString(fileSinkConf.getDirName, conf))
+
+ logger.debug("Saving as hadoop file of type " + valueClass.getSimpleName)
+
+ val writer = new SparkHiveHadoopWriter(conf, fileSinkConf)
+ writer.preSetup()
+
+ def writeToFile(context: TaskContext, iter: Iterator[Writable]) {
+ // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
+ // around by taking a mod. We expect that no task will be attempted 2 billion times.
+ val attemptNumber = (context.attemptId % Int.MaxValue).toInt
+
+ writer.setup(context.stageId, context.partitionId, attemptNumber)
+ writer.open()
+
+ var count = 0
+ while(iter.hasNext) {
+ val record = iter.next()
+ count += 1
+ writer.write(record)
+ }
+
+ writer.close()
+ writer.commit()
+ }
+
+ sc.sparkContext.runJob(rdd, writeToFile _)
+ writer.commitJob()
+ }
+
+ /**
+ * Inserts all the rows in the table into Hive. Row objects are properly serialized with the
+ * `org.apache.hadoop.hive.serde2.SerDe` and the
+ * `org.apache.hadoop.mapred.OutputFormat` provided by the table definition.
+ */
+ def execute() = {
+ val childRdd = child.execute()
+ assert(childRdd != null)
+
+ // Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer
+ // instances within the closure, since Serializer is not serializable while TableDesc is.
+ val tableDesc = table.tableDesc
+ val tableLocation = table.hiveQlTable.getDataLocation
+ val tmpLocation = hiveContext.getExternalTmpFileURI(tableLocation)
+ val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false)
+ val rdd = childRdd.mapPartitions { iter =>
+ val serializer = newSerializer(fileSinkConf.getTableInfo)
+ val standardOI = ObjectInspectorUtils
+ .getStandardObjectInspector(
+ fileSinkConf.getTableInfo.getDeserializer.getObjectInspector,
+ ObjectInspectorCopyOption.JAVA)
+ .asInstanceOf[StructObjectInspector]
+
+ iter.map { row =>
+ // Casts Strings to HiveVarchars when necessary.
+ val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector)
+ val mappedRow = row.zip(fieldOIs).map(wrap)
+
+ serializer.serialize(mappedRow.toArray, standardOI)
+ }
+ }
+
+    // ORC stores compression information in table properties, while other formats (e.g. RCFile)
+    // rely on Hadoop configurations to store compression information.
+ val jobConf = new JobConf(sc.hiveconf)
+ saveAsHiveFile(
+ rdd,
+ outputClass,
+ fileSinkConf,
+ jobConf,
+ sc.hiveconf.getBoolean("hive.exec.compress.output", false))
+
+ // TODO: Handle dynamic partitioning.
+ val outputPath = FileOutputFormat.getOutputPath(jobConf)
+ // Have to construct the format of dbname.tablename.
+ val qualifiedTableName = s"${table.databaseName}.${table.tableName}"
+ // TODO: Correctly set holdDDLTime.
+ // In most of the time, we should have holdDDLTime = false.
+ // holdDDLTime will be true when TOK_HOLD_DDLTIME presents in the query as a hint.
+ val holdDDLTime = false
+ if (partition.nonEmpty) {
+ val partitionSpec = partition.map {
+ case (key, Some(value)) => key -> value
+ case (key, None) => key -> "" // Should not reach here right now.
+ }
+ val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols(), partitionSpec)
+ db.validatePartitionNameCharacters(partVals)
+ // inheritTableSpecs is set to true. It should be set to false for a IMPORT query
+ // which is currently considered as a Hive native command.
+ val inheritTableSpecs = true
+ // TODO: Correctly set isSkewedStoreAsSubdir.
+ val isSkewedStoreAsSubdir = false
+ db.loadPartition(
+ outputPath,
+ qualifiedTableName,
+ partitionSpec,
+ overwrite,
+ holdDDLTime,
+ inheritTableSpecs,
+ isSkewedStoreAsSubdir)
+ } else {
+ db.loadTable(
+ outputPath,
+ qualifiedTableName,
+ overwrite,
+ holdDDLTime)
+ }
+
+    // It would be nice to just return the childRdd unchanged so insert operations could be
+    // chained; however, for now we return an empty list to simplify compatibility checks with
+    // Hive, which does not return anything for insert operations.
+ // TODO: implement hive compatibility as rules.
+ sc.sparkContext.makeRDD(Nil, 1)
+ }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
new file mode 100644
index 0000000000..5e775d6a04
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -0,0 +1,467 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.hadoop.hive.common.`type`.HiveDecimal
+import org.apache.hadoop.hive.serde2.{io => hiveIo}
+import org.apache.hadoop.hive.serde2.objectinspector.primitive._
+import org.apache.hadoop.hive.serde2.objectinspector._
+import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry}
+import org.apache.hadoop.hive.ql.udf.generic._
+import org.apache.hadoop.hive.ql.exec.UDF
+import org.apache.hadoop.{io => hadoopIo}
+
+import catalyst.analysis
+import catalyst.expressions._
+import catalyst.types
+import catalyst.types._
+
+object HiveFunctionRegistry
+ extends analysis.FunctionRegistry with HiveFunctionFactory with HiveInspectors {
+
+ def lookupFunction(name: String, children: Seq[Expression]): Expression = {
+ // We only look it up to see if it exists, but do not include it in the HiveUDF since it is
+ // not always serializable.
+ val functionInfo: FunctionInfo = Option(FunctionRegistry.getFunctionInfo(name)).getOrElse(
+ sys.error(s"Couldn't find function $name"))
+
+ if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
+ val function = createFunction[UDF](name)
+ val method = function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo))
+
+ lazy val expectedDataTypes = method.getParameterTypes.map(javaClassToDataType)
+
+ HiveSimpleUdf(
+ name,
+ children.zip(expectedDataTypes).map { case (e, t) => Cast(e, t) }
+ )
+ } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
+ HiveGenericUdf(name, children)
+ } else if (
+ classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
+ HiveGenericUdaf(name, children)
+
+ } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
+ HiveGenericUdtf(name, Nil, children)
+ } else {
+ sys.error(s"No handler for udf ${functionInfo.getFunctionClass}")
+ }
+ }
+
+ def javaClassToDataType(clz: Class[_]): DataType = clz match {
+ case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType
+ case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType
+ case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType
+ case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType
+ case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType
+ case c: Class[_] if c == classOf[hadoopIo.Text] => StringType
+ case c: Class[_] if c == classOf[hadoopIo.IntWritable] => IntegerType
+ case c: Class[_] if c == classOf[hadoopIo.LongWritable] => LongType
+ case c: Class[_] if c == classOf[hadoopIo.FloatWritable] => FloatType
+ case c: Class[_] if c == classOf[hadoopIo.BooleanWritable] => BooleanType
+ case c: Class[_] if c == classOf[java.lang.String] => StringType
+ case c: Class[_] if c == java.lang.Short.TYPE => ShortType
+ case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType
+ case c: Class[_] if c == java.lang.Long.TYPE => LongType
+ case c: Class[_] if c == java.lang.Double.TYPE => DoubleType
+ case c: Class[_] if c == java.lang.Byte.TYPE => ByteType
+ case c: Class[_] if c == java.lang.Float.TYPE => FloatType
+ case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType
+ case c: Class[_] if c == classOf[java.lang.Short] => ShortType
+ case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
+ case c: Class[_] if c == classOf[java.lang.Long] => LongType
+ case c: Class[_] if c == classOf[java.lang.Double] => DoubleType
+ case c: Class[_] if c == classOf[java.lang.Byte] => ByteType
+ case c: Class[_] if c == classOf[java.lang.Float] => FloatType
+ case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType
+ case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType))
+ }
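+  // For illustration: javaClassToDataType(classOf[hadoopIo.Text]) yields StringType, and an
+  // Array[java.lang.Integer] parameter maps recursively to ArrayType(IntegerType).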
+}
+
+trait HiveFunctionFactory {
+ def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name)
+ def getFunctionClass(name: String) = getFunctionInfo(name).getFunctionClass
+ def createFunction[UDFType](name: String) =
+ getFunctionClass(name).newInstance.asInstanceOf[UDFType]
+
+ /** Converts hive types to native catalyst types. */
+ def unwrap(a: Any): Any = a match {
+ case null => null
+ case i: hadoopIo.IntWritable => i.get
+ case t: hadoopIo.Text => t.toString
+ case l: hadoopIo.LongWritable => l.get
+ case d: hadoopIo.DoubleWritable => d.get()
+ case d: hiveIo.DoubleWritable => d.get
+ case s: hiveIo.ShortWritable => s.get
+ case b: hadoopIo.BooleanWritable => b.get()
+ case b: hiveIo.ByteWritable => b.get
+ case list: java.util.List[_] => list.map(unwrap)
+ case map: java.util.Map[_,_] => map.map { case (k, v) => (unwrap(k), unwrap(v)) }.toMap
+ case array: Array[_] => array.map(unwrap).toSeq
+ case p: java.lang.Short => p
+ case p: java.lang.Long => p
+ case p: java.lang.Float => p
+ case p: java.lang.Integer => p
+ case p: java.lang.Double => p
+ case p: java.lang.Byte => p
+ case p: java.lang.Boolean => p
+ case str: String => str
+ }
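+  // For illustration: unwrap(new hadoopIo.IntWritable(1)) yields 1, and a java.util.List of
+  // Text values becomes a Scala Seq of Strings, matching what catalyst rows expect.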
+}
+
+abstract class HiveUdf
+ extends Expression with Logging with HiveFunctionFactory {
+ self: Product =>
+
+ type UDFType
+ type EvaluatedType = Any
+
+ val name: String
+
+ def nullable = true
+ def references = children.flatMap(_.references).toSet
+
+ // FunctionInfo is not serializable so we must look it up here again.
+ lazy val functionInfo = getFunctionInfo(name)
+ lazy val function = createFunction[UDFType](name)
+
+ override def toString = s"${nodeName}#${functionInfo.getDisplayName}(${children.mkString(",")})"
+}
+
+case class HiveSimpleUdf(name: String, children: Seq[Expression]) extends HiveUdf {
+ import HiveFunctionRegistry._
+ type UDFType = UDF
+
+ @transient
+ protected lazy val method =
+ function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo))
+
+ @transient
+ lazy val dataType = javaClassToDataType(method.getReturnType)
+
+ protected lazy val wrappers: Array[(Any) => AnyRef] = method.getParameterTypes.map { argClass =>
+ val primitiveClasses = Seq(
+ Integer.TYPE, classOf[java.lang.Integer], classOf[java.lang.String], java.lang.Double.TYPE,
+ classOf[java.lang.Double], java.lang.Long.TYPE, classOf[java.lang.Long],
+ classOf[HiveDecimal], java.lang.Byte.TYPE, classOf[java.lang.Byte]
+ )
+ val matchingConstructor = argClass.getConstructors.find { c =>
+ c.getParameterTypes.size == 1 && primitiveClasses.contains(c.getParameterTypes.head)
+ }
+
+ val constructor = matchingConstructor.getOrElse(
+ sys.error(s"No matching wrapper found, options: ${argClass.getConstructors.toSeq}."))
+
+ (a: Any) => {
+ logger.debug(
+ s"Wrapping $a of type ${if (a == null) "null" else a.getClass.getName} using $constructor.")
+ // We must make sure that primitives get boxed, Java style.
+ if (a == null) {
+ null
+ } else {
+ constructor.newInstance(a match {
+ case i: Int => i: java.lang.Integer
+ case bd: BigDecimal => new HiveDecimal(bd.underlying())
+ case other: AnyRef => other
+ }).asInstanceOf[AnyRef]
+ }
+ }
+ }
+
+ // TODO: Finish input/output types.
+ override def apply(input: Row): Any = {
+ val evaluatedChildren = children.map(_.apply(input))
+ // Wrap the function arguments in the expected types.
+ val args = evaluatedChildren.zip(wrappers).map {
+ case (arg, wrapper) => wrapper(arg)
+ }
+
+ // Invoke the udf and unwrap the result.
+ unwrap(method.invoke(function, args: _*))
+ }
+}
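+
+// Roughly, for a simple UDF whose evaluate method takes a hadoopIo.Text argument, the wrapper
+// computed above boxes each Catalyst value through Text's single-argument String constructor,
+// found reflectively (argument values are hypothetical):
+//
+//   wrappers(0)("spark")   // new hadoopIo.Text("spark")
+//   wrappers(0)(null)      // null, since nulls are passed through unchanged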
+
+case class HiveGenericUdf(
+ name: String,
+ children: Seq[Expression]) extends HiveUdf with HiveInspectors {
+ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._
+ type UDFType = GenericUDF
+
+ @transient
+ protected lazy val argumentInspectors = children.map(_.dataType).map(toInspector)
+
+ @transient
+ protected lazy val returnInspector = function.initialize(argumentInspectors.toArray)
+
+ val dataType: DataType = inspectorToDataType(returnInspector)
+
+ override def apply(input: Row): Any = {
+ returnInspector // Make sure initialized.
+ val args = children.map { v =>
+ new DeferredObject {
+ override def prepare(i: Int) = {}
+ override def get(): AnyRef = wrap(v.apply(input))
+ }
+ }.toArray
+ unwrap(function.evaluate(args))
+ }
+}
+
+trait HiveInspectors {
+
+ def unwrapData(data: Any, oi: ObjectInspector): Any = oi match {
+ case pi: PrimitiveObjectInspector => pi.getPrimitiveJavaObject(data)
+ case li: ListObjectInspector =>
+ Option(li.getList(data))
+ .map(_.map(unwrapData(_, li.getListElementObjectInspector)).toSeq)
+ .orNull
+ case mi: MapObjectInspector =>
+ Option(mi.getMap(data)).map(
+ _.map {
+ case (k, v) =>
+ (unwrapData(k, mi.getMapKeyObjectInspector),
+ unwrapData(v, mi.getMapValueObjectInspector))
+ }.toMap).orNull
+ case si: StructObjectInspector =>
+ val allRefs = si.getAllStructFieldRefs
+ new GenericRow(
+ allRefs.map(r =>
+ unwrapData(si.getStructFieldData(data, r), r.getFieldObjectInspector)).toArray)
+ }
+
+ /** Converts native Catalyst types to the types expected by Hive. */
+ def wrap(a: Any): AnyRef = a match {
+ case s: String => new hadoopIo.Text(s)
+ case i: Int => i: java.lang.Integer
+ case b: Boolean => b: java.lang.Boolean
+ case d: Double => d: java.lang.Double
+ case l: Long => l: java.lang.Long
+ case l: Short => l: java.lang.Short
+ case l: Byte => l: java.lang.Byte
+ case s: Seq[_] => seqAsJavaList(s.map(wrap))
+ case m: Map[_, _] =>
+ mapAsJavaMap(m.map { case (k, v) => wrap(k) -> wrap(v) })
+ case null => null
+ }
+
+ def toInspector(dataType: DataType): ObjectInspector = dataType match {
+ case ArrayType(tpe) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe))
+ case MapType(keyType, valueType) =>
+ ObjectInspectorFactory.getStandardMapObjectInspector(
+ toInspector(keyType), toInspector(valueType))
+ case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector
+ case IntegerType => PrimitiveObjectInspectorFactory.javaIntObjectInspector
+ case DoubleType => PrimitiveObjectInspectorFactory.javaDoubleObjectInspector
+ case BooleanType => PrimitiveObjectInspectorFactory.javaBooleanObjectInspector
+ case LongType => PrimitiveObjectInspectorFactory.javaLongObjectInspector
+ case FloatType => PrimitiveObjectInspectorFactory.javaFloatObjectInspector
+ case ShortType => PrimitiveObjectInspectorFactory.javaShortObjectInspector
+ case ByteType => PrimitiveObjectInspectorFactory.javaByteObjectInspector
+ case NullType => PrimitiveObjectInspectorFactory.javaVoidObjectInspector
+ case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector
+ }
+
+ def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match {
+ case s: StructObjectInspector =>
+ StructType(s.getAllStructFieldRefs.map(f => {
+ types.StructField(
+ f.getFieldName, inspectorToDataType(f.getFieldObjectInspector), nullable = true)
+ }))
+ case l: ListObjectInspector => ArrayType(inspectorToDataType(l.getListElementObjectInspector))
+ case m: MapObjectInspector =>
+ MapType(
+ inspectorToDataType(m.getMapKeyObjectInspector),
+ inspectorToDataType(m.getMapValueObjectInspector))
+ case _: WritableStringObjectInspector => StringType
+ case _: JavaStringObjectInspector => StringType
+ case _: WritableIntObjectInspector => IntegerType
+ case _: JavaIntObjectInspector => IntegerType
+ case _: WritableDoubleObjectInspector => DoubleType
+ case _: JavaDoubleObjectInspector => DoubleType
+ case _: WritableBooleanObjectInspector => BooleanType
+ case _: JavaBooleanObjectInspector => BooleanType
+ case _: WritableLongObjectInspector => LongType
+ case _: JavaLongObjectInspector => LongType
+ case _: WritableShortObjectInspector => ShortType
+ case _: JavaShortObjectInspector => ShortType
+ case _: WritableByteObjectInspector => ByteType
+ case _: JavaByteObjectInspector => ByteType
+ }
+
+ implicit class typeInfoConversions(dt: DataType) {
+ import org.apache.hadoop.hive.serde2.typeinfo._
+ import TypeInfoFactory._
+
+ def toTypeInfo: TypeInfo = dt match {
+ case BinaryType => binaryTypeInfo
+ case BooleanType => booleanTypeInfo
+ case ByteType => byteTypeInfo
+ case DoubleType => doubleTypeInfo
+ case FloatType => floatTypeInfo
+ case IntegerType => intTypeInfo
+ case LongType => longTypeInfo
+ case ShortType => shortTypeInfo
+ case StringType => stringTypeInfo
+ case DecimalType => decimalTypeInfo
+ case NullType => voidTypeInfo
+ }
+ }
+}
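+
+// The conversions above are roughly inverse to each other; for example:
+//
+//   toInspector(ArrayType(IntegerType))
+//     // a standard list inspector backed by javaIntObjectInspector
+//   inspectorToDataType(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
+//     // StringType
+//   StringType.toTypeInfo
+//     // TypeInfoFactory.stringTypeInfo, via the typeInfoConversions implicit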
+
+case class HiveGenericUdaf(
+ name: String,
+ children: Seq[Expression]) extends AggregateExpression
+ with HiveInspectors
+ with HiveFunctionFactory {
+
+ type UDFType = AbstractGenericUDAFResolver
+
+ protected lazy val resolver: AbstractGenericUDAFResolver = createFunction(name)
+
+ protected lazy val objectInspector = {
+ resolver.getEvaluator(children.map(_.dataType.toTypeInfo).toArray)
+ .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
+ }
+
+ protected lazy val inspectors = children.map(_.dataType).map(toInspector)
+
+ def dataType: DataType = inspectorToDataType(objectInspector)
+
+ def nullable: Boolean = true
+
+ def references: Set[Attribute] = children.map(_.references).flatten.toSet
+
+ override def toString = s"$nodeName#$name(${children.mkString(",")})"
+
+ def newInstance = new HiveUdafFunction(name, children, this)
+}
+
+/**
+ * Converts a Hive Generic User Defined Table Generating Function (UDTF) to a
+ * [[catalyst.expressions.Generator Generator]]. Note that the semantics of Generators do not
+ * allow them to maintain state between input rows. Thus, UDTFs that rely on
+ * partitioning-dependent operations, such as calls to `close()` before producing output, will
+ * not behave the same way as they do in Hive. In practice, however, this should not affect
+ * compatibility for most reasonable UDTFs (e.g. explode or GenericUDTFParseUrlTuple).
+ *
+ * Operators that require maintaining state between input rows should instead be implemented
+ * as user defined aggregations, which have clean semantics even in a partitioned execution.
+ */
+case class HiveGenericUdtf(
+ name: String,
+ aliasNames: Seq[String],
+ children: Seq[Expression])
+ extends Generator with HiveInspectors with HiveFunctionFactory {
+
+ override def references = children.flatMap(_.references).toSet
+
+ @transient
+ protected lazy val function: GenericUDTF = createFunction(name)
+
+ protected lazy val inputInspectors = children.map(_.dataType).map(toInspector)
+
+ protected lazy val outputInspectors = {
+ val structInspector = function.initialize(inputInspectors.toArray)
+ structInspector.getAllStructFieldRefs.map(_.getFieldObjectInspector)
+ }
+
+ protected lazy val outputDataTypes = outputInspectors.map(inspectorToDataType)
+
+ override protected def makeOutput() = {
+ // Use the supplied column aliases when given, otherwise default to c_0, c_1, ... c_{n-1}.
+ if (aliasNames.size == outputDataTypes.size) {
+ aliasNames.zip(outputDataTypes).map {
+ case (attrName, attrDataType) =>
+ AttributeReference(attrName, attrDataType, nullable = true)()
+ }
+ } else {
+ outputDataTypes.zipWithIndex.map {
+ case (attrDataType, i) =>
+ AttributeReference(s"c_$i", attrDataType, nullable = true)()
+ }
+ }
+ }
+
+ override def apply(input: Row): TraversableOnce[Row] = {
+ outputInspectors // Make sure initialized.
+
+ val inputProjection = new Projection(children)
+ val collector = new UDTFCollector
+ function.setCollector(collector)
+
+ val udtInput = inputProjection(input).map(wrap).toArray
+ function.process(udtInput)
+ collector.collectRows()
+ }
+
+ protected class UDTFCollector extends Collector {
+ var collected = new ArrayBuffer[Row]
+
+ override def collect(input: java.lang.Object) {
+ // We need to clone the input here because implementations of
+ // GenericUDTF reuse the same object. Luckily they are always an array, so
+ // it is easy to clone.
+ collected += new GenericRow(input.asInstanceOf[Array[_]].map(unwrap))
+ }
+
+ def collectRows() = {
+ val toCollect = collected
+ collected = new ArrayBuffer[Row]
+ toCollect
+ }
+ }
+
+ override def toString = s"$nodeName#$name(${children.mkString(",")})"
+}
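+
+// For example, Hive's explode UDTF over a single array-valued child could be expressed
+// roughly as (the child expression below is hypothetical):
+//
+//   HiveGenericUdtf("explode", Nil, Seq(arrayChild))
+//     // no alias names given, so the single output attribute is named c_0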
+
+case class HiveUdafFunction(
+ functionName: String,
+ exprs: Seq[Expression],
+ base: AggregateExpression)
+ extends AggregateFunction
+ with HiveInspectors
+ with HiveFunctionFactory {
+
+ def this() = this(null, null, null)
+
+ private val resolver = createFunction[AbstractGenericUDAFResolver](functionName)
+
+ private val inspectors = exprs.map(_.dataType).map(toInspector).toArray
+
+ private val function = resolver.getEvaluator(exprs.map(_.dataType.toTypeInfo).toArray)
+
+ private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors)
+
+ // Cast required to avoid type inference selecting a deprecated Hive API.
+ private val buffer =
+ function.getNewAggregationBuffer.asInstanceOf[GenericUDAFEvaluator.AbstractAggregationBuffer]
+
+ override def apply(input: Row): Any = unwrapData(function.evaluate(buffer), returnInspector)
+
+ @transient
+ val inputProjection = new Projection(exprs)
+
+ def update(input: Row): Unit = {
+ val inputs = inputProjection(input).asInstanceOf[Seq[AnyRef]].toArray
+ function.iterate(buffer, inputs)
+ }
+}
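+
+// Evaluation lifecycle, roughly: the aggregation operator obtains this function via
+// HiveGenericUdaf.newInstance, feeds it rows through update(), and reads the final value
+// with apply(), whose Row argument is unused (names below are hypothetical):
+//
+//   val udaf = hiveGenericUdaf.newInstance
+//   partitionRows.foreach(udaf.update)
+//   udaf.apply(finalRow)   // unwrapped result of the Hive evaluator in COMPLETE mode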