Diffstat (limited to 'sql/hive')
-rw-r--r--  sql/hive/pom.xml | 81
-rw-r--r--  sql/hive/src/main/scala/org/apache/hadoop/mapred/SparkHadoopWriter.scala | 198
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala | 287
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 246
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 966
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala | 164
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/ScriptTransformation.scala | 76
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala | 243
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala | 341
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala | 356
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala | 467
-rw-r--r--  sql/hive/src/test/resources/log4j.properties | 47
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala | 126
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala | 38
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala | 379
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala | 708
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala | 70
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala | 144
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala | 65
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala | 33
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala | 32
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala | 164
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala | 161
23 files changed, 5392 insertions, 0 deletions
diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml
new file mode 100644
index 0000000000..7b5ea98f27
--- /dev/null
+++ b/sql/hive/pom.xml
@@ -0,0 +1,81 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-parent</artifactId>
+ <version>1.0.0-SNAPSHOT</version>
+ <relativePath>../../pom.xml</relativePath>
+ </parent>
+
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-hive_2.10</artifactId>
+ <packaging>jar</packaging>
+ <name>Spark Project Hive</name>
+ <url>http://spark.apache.org/</url>
+
+ <dependencies>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-core_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-sql_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hive</groupId>
+ <artifactId>hive-metastore</artifactId>
+ <version>${hive.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hive</groupId>
+ <artifactId>hive-exec</artifactId>
+ <version>${hive.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hive</groupId>
+ <artifactId>hive-serde</artifactId>
+ <version>${hive.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest_${scala.binary.version}</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.scalacheck</groupId>
+ <artifactId>scalacheck_${scala.binary.version}</artifactId>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+ <build>
+ <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+ <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+ <plugins>
+ <plugin>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest-maven-plugin</artifactId>
+ </plugin>
+ </plugins>
+ </build>
+</project>
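
The new spark-hive_2.10 module above depends on spark-core and spark-sql from the same build, plus the hive-metastore, hive-exec, and hive-serde artifacts at ${hive.version}. As a rough sketch of how a downstream sbt project might depend on it once the artifact is published (the sbt coordinates and version here are assumptions inferred from this pom, not something the patch itself provides):

  // build.sbt sketch; "spark-hive" %% resolves to spark-hive_2.10 under Scala 2.10.
  libraryDependencies += "org.apache.spark" %% "spark-hive" % "1.0.0-SNAPSHOT"
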
diff --git a/sql/hive/src/main/scala/org/apache/hadoop/mapred/SparkHadoopWriter.scala b/sql/hive/src/main/scala/org/apache/hadoop/mapred/SparkHadoopWriter.scala
new file mode 100644
index 0000000000..08d390e887
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/hadoop/mapred/SparkHadoopWriter.scala
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mapred
+
+import java.io.IOException
+import java.text.NumberFormat
+import java.util.Date
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.Writable
+
+import org.apache.spark.Logging
+import org.apache.spark.SerializableWritable
+
+import org.apache.hadoop.hive.ql.exec.{Utilities, FileSinkOperator}
+import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat}
+import org.apache.hadoop.hive.ql.plan.FileSinkDesc
+
+/**
+ * Internal helper class that saves an RDD using a Hive OutputFormat.
+ * It is based on [[SparkHadoopWriter]].
+ */
+protected[apache]
+class SparkHiveHadoopWriter(
+ @transient jobConf: JobConf,
+ fileSinkConf: FileSinkDesc)
+ extends Logging
+ with SparkHadoopMapRedUtil
+ with Serializable {
+
+ private val now = new Date()
+ private val conf = new SerializableWritable(jobConf)
+
+ private var jobID = 0
+ private var splitID = 0
+ private var attemptID = 0
+ private var jID: SerializableWritable[JobID] = null
+ private var taID: SerializableWritable[TaskAttemptID] = null
+
+ @transient private var writer: FileSinkOperator.RecordWriter = null
+ @transient private var format: HiveOutputFormat[AnyRef, Writable] = null
+ @transient private var committer: OutputCommitter = null
+ @transient private var jobContext: JobContext = null
+ @transient private var taskContext: TaskAttemptContext = null
+
+ def preSetup() {
+ setIDs(0, 0, 0)
+ setConfParams()
+
+ val jCtxt = getJobContext()
+ getOutputCommitter().setupJob(jCtxt)
+ }
+
+
+ def setup(jobid: Int, splitid: Int, attemptid: Int) {
+ setIDs(jobid, splitid, attemptid)
+ setConfParams()
+ }
+
+ def open() {
+ val numfmt = NumberFormat.getInstance()
+ numfmt.setMinimumIntegerDigits(5)
+ numfmt.setGroupingUsed(false)
+
+ val extension = Utilities.getFileExtension(
+ conf.value,
+ fileSinkConf.getCompressed,
+ getOutputFormat())
+
+ val outputName = "part-" + numfmt.format(splitID) + extension
+ val path = FileOutputFormat.getTaskOutputPath(conf.value, outputName)
+
+ getOutputCommitter().setupTask(getTaskContext())
+ writer = HiveFileFormatUtils.getHiveRecordWriter(
+ conf.value,
+ fileSinkConf.getTableInfo,
+ conf.value.getOutputValueClass.asInstanceOf[Class[Writable]],
+ fileSinkConf,
+ path,
+ null)
+ }
+
+ def write(value: Writable) {
+ if (writer != null) {
+ writer.write(value)
+ } else {
+ throw new IOException("Writer is null, open() has not been called")
+ }
+ }
+
+ def close() {
+ // Seems the boolean value passed into close does not matter.
+ writer.close(false)
+ }
+
+ def commit() {
+ val taCtxt = getTaskContext()
+ val cmtr = getOutputCommitter()
+ if (cmtr.needsTaskCommit(taCtxt)) {
+ try {
+ cmtr.commitTask(taCtxt)
+ logInfo (taID + ": Committed")
+ } catch {
+ case e: IOException => {
+ logError("Error committing the output of task: " + taID.value, e)
+ cmtr.abortTask(taCtxt)
+ throw e
+ }
+ }
+ } else {
+ logWarning ("No need to commit output of task: " + taID.value)
+ }
+ }
+
+ def commitJob() {
+ // always ? Or if cmtr.needsTaskCommit ?
+ val cmtr = getOutputCommitter()
+ cmtr.commitJob(getJobContext())
+ }
+
+ // ********* Private Functions *********
+
+ private def getOutputFormat(): HiveOutputFormat[AnyRef,Writable] = {
+ if (format == null) {
+ format = conf.value.getOutputFormat()
+ .asInstanceOf[HiveOutputFormat[AnyRef,Writable]]
+ }
+ format
+ }
+
+ private def getOutputCommitter(): OutputCommitter = {
+ if (committer == null) {
+ committer = conf.value.getOutputCommitter
+ }
+ committer
+ }
+
+ private def getJobContext(): JobContext = {
+ if (jobContext == null) {
+ jobContext = newJobContext(conf.value, jID.value)
+ }
+ jobContext
+ }
+
+ private def getTaskContext(): TaskAttemptContext = {
+ if (taskContext == null) {
+ taskContext = newTaskAttemptContext(conf.value, taID.value)
+ }
+ taskContext
+ }
+
+ private def setIDs(jobid: Int, splitid: Int, attemptid: Int) {
+ jobID = jobid
+ splitID = splitid
+ attemptID = attemptid
+
+ jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobid))
+ taID = new SerializableWritable[TaskAttemptID](
+ new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID))
+ }
+
+ private def setConfParams() {
+ conf.value.set("mapred.job.id", jID.value.toString)
+ conf.value.set("mapred.tip.id", taID.value.getTaskID.toString)
+ conf.value.set("mapred.task.id", taID.value.toString)
+ conf.value.setBoolean("mapred.task.is.map", true)
+ conf.value.setInt("mapred.task.partition", splitID)
+ }
+}
+
+object SparkHiveHadoopWriter {
+ def createPathFromString(path: String, conf: JobConf): Path = {
+ if (path == null) {
+ throw new IllegalArgumentException("Output path is null")
+ }
+ val outputPath = new Path(path)
+ val fs = outputPath.getFileSystem(conf)
+ if (outputPath == null || fs == null) {
+ throw new IllegalArgumentException("Incorrectly formatted output path")
+ }
+ outputPath.makeQualified(fs)
+ }
+}
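
SparkHiveHadoopWriter is driven through an explicit lifecycle: preSetup() once for the job, then setup()/open()/write()/close()/commit() per task attempt, and commitJob() at the end. Below is a minimal sketch of that call order, based only on the methods defined above; the JobConf, FileSinkDesc, IDs, and row iterator are assumed to be prepared by the caller, and since the class is protected[apache] the code would have to live inside the org.apache package.

  // Sketch of the intended call sequence; all parameters are assumed inputs.
  def writeOnePartition(
      jobConf: org.apache.hadoop.mapred.JobConf,
      fileSinkConf: org.apache.hadoop.hive.ql.plan.FileSinkDesc,
      jobId: Int, splitId: Int, attemptId: Int,
      rows: Iterator[org.apache.hadoop.io.Writable]): Unit = {
    val writer = new SparkHiveHadoopWriter(jobConf, fileSinkConf)
    writer.preSetup()                        // job level: set up the OutputCommitter
    writer.setup(jobId, splitId, attemptId)  // task level: record IDs and mapred.* conf params
    writer.open()                            // create the Hive RecordWriter for this split
    rows.foreach(writer.write)               // values must already be serialized to Writable
    writer.close()
    writer.commit()                          // commit this task's output via the committer
    writer.commitJob()                       // finalize the job output
  }
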
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
new file mode 100644
index 0000000000..4aad876cc0
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -0,0 +1,287 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+
+import java.io.{PrintStream, InputStreamReader, BufferedReader, File}
+import java.util.{ArrayList => JArrayList}
+import scala.language.implicitConversions
+
+import org.apache.spark.SparkContext
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.ql.session.SessionState
+import org.apache.hadoop.hive.ql.processors.{CommandProcessorResponse, CommandProcessorFactory}
+import org.apache.hadoop.hive.ql.processors.CommandProcessor
+import org.apache.hadoop.hive.ql.Driver
+import org.apache.spark.rdd.RDD
+
+import catalyst.analysis.{Analyzer, OverrideCatalog}
+import catalyst.expressions.GenericRow
+import catalyst.plans.logical.{BaseRelation, LogicalPlan, NativeCommand, ExplainCommand}
+import catalyst.types._
+
+import org.apache.spark.sql.execution._
+
+import scala.collection.JavaConversions._
+
+/**
+ * Starts up an instance of Hive where metadata is stored locally. An in-process metastore is
+ * created with data stored in ./metastore. Warehouse data is stored in ./warehouse.
+ */
+class LocalHiveContext(sc: SparkContext) extends HiveContext(sc) {
+
+ lazy val metastorePath = new File("metastore").getCanonicalPath
+ lazy val warehousePath: String = new File("warehouse").getCanonicalPath
+
+ /** Sets up the system initially or after a RESET command */
+ protected def configure() {
+ // TODO: refactor this so we can work with other databases.
+ runSqlHive(
+ s"set javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$metastorePath;create=true")
+ runSqlHive("set hive.metastore.warehouse.dir=" + warehousePath)
+ }
+
+ configure() // Must be called before initializing the catalog below.
+}
+
+/**
+ * An instance of the Spark SQL execution engine that integrates with data stored in Hive.
+ * Configuration for Hive is read from hive-site.xml on the classpath.
+ */
+class HiveContext(sc: SparkContext) extends SQLContext(sc) {
+ self =>
+
+ override def parseSql(sql: String): LogicalPlan = HiveQl.parseSql(sql)
+ override def executePlan(plan: LogicalPlan): this.QueryExecution =
+ new this.QueryExecution { val logical = plan }
+
+ // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur.
+ @transient
+ protected val outputBuffer = new java.io.OutputStream {
+ var pos: Int = 0
+ var buffer = new Array[Int](10240)
+ def write(i: Int): Unit = {
+ buffer(pos) = i
+ pos = (pos + 1) % buffer.size
+ }
+
+ override def toString = {
+ val (end, start) = buffer.splitAt(pos)
+ val input = new java.io.InputStream {
+ val iterator = (start ++ end).iterator
+
+ def read(): Int = if (iterator.hasNext) iterator.next else -1
+ }
+ val reader = new BufferedReader(new InputStreamReader(input))
+ val stringBuilder = new StringBuilder
+ var line = reader.readLine()
+ while(line != null) {
+ stringBuilder.append(line)
+ stringBuilder.append("\n")
+ line = reader.readLine()
+ }
+ stringBuilder.toString()
+ }
+ }
+
+ @transient protected[hive] lazy val hiveconf = new HiveConf(classOf[SessionState])
+ @transient protected[hive] lazy val sessionState = new SessionState(hiveconf)
+
+ sessionState.err = new PrintStream(outputBuffer, true, "UTF-8")
+ sessionState.out = new PrintStream(outputBuffer, true, "UTF-8")
+
+ /* A catalyst metadata catalog that points to the Hive Metastore. */
+ @transient
+ override lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog
+
+ /* An analyzer that uses the Hive metastore. */
+ @transient
+ override lazy val analyzer = new Analyzer(catalog, HiveFunctionRegistry, caseSensitive = false)
+
+ def tables: Seq[BaseRelation] = {
+ // TODO: Move this functionality to Catalog. Make client protected.
+ val allTables = catalog.client.getAllTables("default")
+ allTables.map(catalog.lookupRelation(None, _, None)).collect { case b: BaseRelation => b }
+ }
+
+ /**
+ * Runs the specified SQL query using Hive.
+ */
+ protected def runSqlHive(sql: String): Seq[String] = {
+ val maxResults = 100000
+ val results = runHive(sql, maxResults)
+ // It is very confusing when you only get back some of the results...
+ if (results.size == maxResults) sys.error("RESULTS POSSIBLY TRUNCATED")
+ results
+ }
+
+ // TODO: Move this.
+
+ SessionState.start(sessionState)
+
+ /**
+ * Execute the command using Hive and return the results as a sequence. Each element
+ * in the sequence is one row.
+ */
+ protected def runHive(cmd: String, maxRows: Int = 1000): Seq[String] = {
+ try {
+ val cmd_trimmed: String = cmd.trim()
+ val tokens: Array[String] = cmd_trimmed.split("\\s+")
+ val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim()
+ val proc: CommandProcessor = CommandProcessorFactory.get(tokens(0), hiveconf)
+
+ SessionState.start(sessionState)
+
+ if (proc.isInstanceOf[Driver]) {
+ val driver: Driver = proc.asInstanceOf[Driver]
+ driver.init()
+
+ val results = new JArrayList[String]
+ val response: CommandProcessorResponse = driver.run(cmd)
+ // Throw an exception if there is an error in query processing.
+ if (response.getResponseCode != 0) {
+ driver.destroy()
+ throw new QueryExecutionException(response.getErrorMessage)
+ }
+ driver.setMaxRows(maxRows)
+ driver.getResults(results)
+ driver.destroy()
+ results
+ } else {
+ sessionState.out.println(tokens(0) + " " + cmd_1)
+ Seq(proc.run(cmd_1).getResponseCode.toString)
+ }
+ } catch {
+ case e: Exception =>
+ logger.error(
+ s"""
+ |======================
+ |HIVE FAILURE OUTPUT
+ |======================
+ |${outputBuffer.toString}
+ |======================
+ |END HIVE FAILURE OUTPUT
+ |======================
+ """.stripMargin)
+ throw e
+ }
+ }
+
+ @transient
+ val hivePlanner = new SparkPlanner with HiveStrategies {
+ val hiveContext = self
+
+ override val strategies: Seq[Strategy] = Seq(
+ TopK,
+ ColumnPrunings,
+ PartitionPrunings,
+ HiveTableScans,
+ DataSinks,
+ Scripts,
+ PartialAggregation,
+ SparkEquiInnerJoin,
+ BasicOperators,
+ CartesianProduct,
+ BroadcastNestedLoopJoin
+ )
+ }
+
+ @transient
+ override val planner = hivePlanner
+
+ @transient
+ protected lazy val emptyResult =
+ sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1)
+
+ /** Extends QueryExecution with hive specific features. */
+ abstract class QueryExecution extends super.QueryExecution {
+ // TODO: Create mixin for the analyzer instead of overriding things here.
+ override lazy val optimizedPlan =
+ optimizer(catalog.PreInsertionCasts(catalog.CreateTables(analyzed)))
+
+ // TODO: We are losing the schema here.
+ override lazy val toRdd: RDD[Row] =
+ analyzed match {
+ case NativeCommand(cmd) =>
+ val output = runSqlHive(cmd)
+
+ if (output.size == 0) {
+ emptyResult
+ } else {
+ val asRows = output.map(r => new GenericRow(r.split("\t").asInstanceOf[Array[Any]]))
+ sparkContext.parallelize(asRows, 1)
+ }
+ case _ =>
+ executedPlan.execute.map(_.copy())
+ }
+
+ protected val primitiveTypes =
+ Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType,
+ ShortType, DecimalType)
+
+ protected def toHiveString(a: (Any, DataType)): String = a match {
+ case (struct: Row, StructType(fields)) =>
+ struct.zip(fields).map {
+ case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}"""
+ }.mkString("{", ",", "}")
+ case (seq: Seq[_], ArrayType(typ))=>
+ seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]")
+ case (map: Map[_,_], MapType(kType, vType)) =>
+ map.map {
+ case (key, value) =>
+ toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType))
+ }.toSeq.sorted.mkString("{", ",", "}")
+ case (null, _) => "NULL"
+ case (other, tpe) if primitiveTypes contains tpe => other.toString
+ }
+
+ /** Hive outputs fields of structs slightly differently than top level attributes. */
+ protected def toHiveStructString(a: (Any, DataType)): String = a match {
+ case (struct: Row, StructType(fields)) =>
+ struct.zip(fields).map {
+ case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}"""
+ }.mkString("{", ",", "}")
+ case (seq: Seq[_], ArrayType(typ))=>
+ seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]")
+ case (map: Map[_,_], MapType(kType, vType)) =>
+ map.map {
+ case (key, value) =>
+ toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType))
+ }.toSeq.sorted.mkString("{", ",", "}")
+ case (null, _) => "null"
+ case (s: String, StringType) => "\"" + s + "\""
+ case (other, tpe) if primitiveTypes contains tpe => other.toString
+ }
+
+ /**
+ * Returns the result as a hive compatible sequence of strings. For native commands, the
+ * execution is simply passed back to Hive.
+ */
+ def stringResult(): Seq[String] = analyzed match {
+ case NativeCommand(cmd) => runSqlHive(cmd)
+ case ExplainCommand(plan) => new QueryExecution { val logical = plan }.toString.split("\n")
+ case query =>
+ val result: Seq[Seq[Any]] = toRdd.collect().toSeq
+ // We need the types so we can output struct field names
+ val types = analyzed.output.map(_.dataType)
+ // Reformat to match hive tab delimited output.
+ val asString = result.map(_.zip(types).map(toHiveString)).map(_.mkString("\t")).toSeq
+ asString
+ }
+ }
+}
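
HiveContext wires HiveQL parsing (parseSql), a metastore-backed catalog, and the Hive-aware planner into the normal SQLContext machinery, while LocalHiveContext additionally points the metastore and warehouse at the current working directory. A small usage sketch using only members visible in this file; the table name src is hypothetical and would already need to exist in the metastore:

  import org.apache.spark.SparkContext
  import org.apache.spark.sql.hive.LocalHiveContext

  // Sketch only: parse a HiveQL string, execute the resulting plan, and list relations.
  def demo(sc: SparkContext): Unit = {
    val hiveContext = new LocalHiveContext(sc)   // uses ./metastore and ./warehouse locally

    // Statements HiveQl cannot plan come back as NativeCommands and are run by Hive itself.
    val plan = hiveContext.parseSql("SELECT key, value FROM src LIMIT 10")
    val rows = hiveContext.executePlan(plan).toRdd.collect()
    rows.foreach(println)

    // BaseRelations currently registered in the "default" metastore database.
    val relations = hiveContext.tables
  }
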
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
new file mode 100644
index 0000000000..e4d50722ce
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+
+import scala.util.parsing.combinator.RegexParsers
+
+import org.apache.hadoop.hive.metastore.api.{FieldSchema, StorageDescriptor, SerDeInfo}
+import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition}
+import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table}
+import org.apache.hadoop.hive.ql.plan.TableDesc
+import org.apache.hadoop.hive.ql.session.SessionState
+import org.apache.hadoop.hive.serde2.Deserializer
+
+import catalyst.analysis.Catalog
+import catalyst.expressions._
+import catalyst.plans.logical
+import catalyst.plans.logical._
+import catalyst.rules._
+import catalyst.types._
+
+import scala.collection.JavaConversions._
+
+class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with Logging {
+ import HiveMetastoreTypes._
+
+ val client = Hive.get(hive.hiveconf)
+
+ def lookupRelation(
+ db: Option[String],
+ tableName: String,
+ alias: Option[String]): LogicalPlan = {
+ val databaseName = db.getOrElse(hive.sessionState.getCurrentDatabase())
+ val table = client.getTable(databaseName, tableName)
+ val partitions: Seq[Partition] =
+ if (table.isPartitioned) {
+ client.getPartitions(table)
+ } else {
+ Nil
+ }
+
+ // Since HiveQL is case insensitive for table names we make them all lowercase.
+ MetastoreRelation(
+ databaseName.toLowerCase,
+ tableName.toLowerCase,
+ alias)(table.getTTable, partitions.map(part => part.getTPartition))
+ }
+
+ def createTable(databaseName: String, tableName: String, schema: Seq[Attribute]) {
+ val table = new Table(databaseName, tableName)
+ val hiveSchema =
+ schema.map(attr => new FieldSchema(attr.name, toMetastoreType(attr.dataType), ""))
+ table.setFields(hiveSchema)
+
+ val sd = new StorageDescriptor()
+ table.getTTable.setSd(sd)
+ sd.setCols(hiveSchema)
+
+ // TODO: THESE ARE ALL DEFAULTS, WE NEED TO PARSE / UNDERSTAND the output specs.
+ sd.setCompressed(false)
+ sd.setParameters(Map[String, String]())
+ sd.setInputFormat("org.apache.hadoop.mapred.TextInputFormat")
+ sd.setOutputFormat("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")
+ val serDeInfo = new SerDeInfo()
+ serDeInfo.setName(tableName)
+ serDeInfo.setSerializationLib("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")
+ serDeInfo.setParameters(Map[String, String]())
+ sd.setSerdeInfo(serDeInfo)
+ client.createTable(table)
+ }
+
+ /**
+ * Creates any tables required for query execution.
+ * For example, because of a CREATE TABLE X AS statement.
+ */
+ object CreateTables extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case InsertIntoCreatedTable(db, tableName, child) =>
+ val databaseName = db.getOrElse(SessionState.get.getCurrentDatabase())
+
+ createTable(databaseName, tableName, child.output)
+
+ InsertIntoTable(
+ lookupRelation(Some(databaseName), tableName, None).asInstanceOf[BaseRelation],
+ Map.empty,
+ child,
+ overwrite = false)
+ }
+ }
+
+ /**
+ * Casts input data to correct data types according to table definition before inserting into
+ * that table.
+ */
+ object PreInsertionCasts extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
+ // Wait until children are resolved
+ case p: LogicalPlan if !p.childrenResolved => p
+
+ case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) =>
+ val childOutputDataTypes = child.output.map(_.dataType)
+ // Only check attributes, not partitionKeys since they are always strings.
+ // TODO: Fully support inserting into partitioned tables.
+ val tableOutputDataTypes = table.attributes.map(_.dataType)
+
+ if (childOutputDataTypes == tableOutputDataTypes) {
+ p
+ } else {
+ // Only do the casting when child output data types differ from table output data types.
+ val castedChildOutput = child.output.zip(table.output).map {
+ case (input, table) if input.dataType != table.dataType =>
+ Alias(Cast(input, table.dataType), input.name)()
+ case (input, _) => input
+ }
+
+ p.copy(child = logical.Project(castedChildOutput, child))
+ }
+ }
+ }
+
+ /**
+ * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore.
+ * For now, if this functionality is desired, mix in the in-memory [[OverrideCatalog]].
+ */
+ override def registerTable(
+ databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = ???
+}
+
+object HiveMetastoreTypes extends RegexParsers {
+ protected lazy val primitiveType: Parser[DataType] =
+ "string" ^^^ StringType |
+ "float" ^^^ FloatType |
+ "int" ^^^ IntegerType |
+ "tinyint" ^^^ ShortType |
+ "double" ^^^ DoubleType |
+ "bigint" ^^^ LongType |
+ "binary" ^^^ BinaryType |
+ "boolean" ^^^ BooleanType |
+ "decimal" ^^^ DecimalType |
+ "varchar\\((\\d+)\\)".r ^^^ StringType
+
+ protected lazy val arrayType: Parser[DataType] =
+ "array" ~> "<" ~> dataType <~ ">" ^^ ArrayType
+
+ protected lazy val mapType: Parser[DataType] =
+ "map" ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ {
+ case t1 ~ _ ~ t2 => MapType(t1, t2)
+ }
+
+ protected lazy val structField: Parser[StructField] =
+ "[a-zA-Z0-9]*".r ~ ":" ~ dataType ^^ {
+ case name ~ _ ~ tpe => StructField(name, tpe, nullable = true)
+ }
+
+ protected lazy val structType: Parser[DataType] =
+ "struct" ~> "<" ~> repsep(structField,",") <~ ">" ^^ StructType
+
+ protected lazy val dataType: Parser[DataType] =
+ arrayType |
+ mapType |
+ structType |
+ primitiveType
+
+ def toDataType(metastoreType: String): DataType = parseAll(dataType, metastoreType) match {
+ case Success(result, _) => result
+ case failure: NoSuccess => sys.error(s"Unsupported dataType: $metastoreType")
+ }
+
+ def toMetastoreType(dt: DataType): String = dt match {
+ case ArrayType(elementType) => s"array<${toMetastoreType(elementType)}>"
+ case StructType(fields) =>
+ s"struct<${fields.map(f => s"${f.name}:${toMetastoreType(f.dataType)}").mkString(",")}>"
+ case MapType(keyType, valueType) =>
+ s"map<${toMetastoreType(keyType)},${toMetastoreType(valueType)}>"
+ case StringType => "string"
+ case FloatType => "float"
+ case IntegerType => "int"
+ case ShortType => "tinyint"
+ case DoubleType => "double"
+ case LongType => "bigint"
+ case BinaryType => "binary"
+ case BooleanType => "boolean"
+ case DecimalType => "decimal"
+ }
+}
+
+case class MetastoreRelation(databaseName: String, tableName: String, alias: Option[String])
+ (val table: TTable, val partitions: Seq[TPartition])
+ extends BaseRelation {
+ // TODO: Can we use org.apache.hadoop.hive.ql.metadata.Table as the type of table, and
+ // use org.apache.hadoop.hive.ql.metadata.Partition as the type of the elements of partitions?
+ // Right now, using org.apache.hadoop.hive.ql.metadata.Table and
+ // org.apache.hadoop.hive.ql.metadata.Partition will cause a NotSerializableException
+ // which indicates the SerDe we used is not Serializable.
+
+ def hiveQlTable = new Table(table)
+
+ def hiveQlPartitions = partitions.map { p =>
+ new Partition(hiveQlTable, p)
+ }
+
+ override def isPartitioned = hiveQlTable.isPartitioned
+
+ val tableDesc = new TableDesc(
+ Class.forName(hiveQlTable.getSerializationLib).asInstanceOf[Class[Deserializer]],
+ hiveQlTable.getInputFormatClass,
+ // The class of table should be org.apache.hadoop.hive.ql.metadata.Table because
+ // getOutputFormatClass will use HiveFileFormatUtils.getOutputFormatSubstitute to
+ // substitute some output formats, e.g. substituting SequenceFileOutputFormat to
+ // HiveSequenceFileOutputFormat.
+ hiveQlTable.getOutputFormatClass,
+ hiveQlTable.getMetadata
+ )
+
+ implicit class SchemaAttribute(f: FieldSchema) {
+ def toAttribute = AttributeReference(
+ f.getName,
+ HiveMetastoreTypes.toDataType(f.getType),
+ // Since data can be dumped in randomly with no validation, everything is nullable.
+ nullable = true
+ )(qualifiers = tableName +: alias.toSeq)
+ }
+
+ // Must be a stable value since new attributes are born here.
+ val partitionKeys = hiveQlTable.getPartitionKeys.map(_.toAttribute)
+
+ /** Non-partitionKey attributes */
+ val attributes = table.getSd.getCols.map(_.toAttribute)
+
+ val output = attributes ++ partitionKeys
+}
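
HiveMetastoreTypes is a small RegexParsers grammar that converts metastore type strings into catalyst DataTypes, and toMetastoreType renders them back. A round-trip sketch using only the members defined above (the struct literal is an illustrative example, not taken from a real table):

  import org.apache.spark.sql.hive.HiveMetastoreTypes

  // Metastore type string -> catalyst DataType.
  val dt = HiveMetastoreTypes.toDataType("struct<name:string,scores:array<int>>")
  // Yields StructType(Seq(
  //   StructField("name", StringType, nullable = true),
  //   StructField("scores", ArrayType(IntegerType), nullable = true)))

  // Catalyst DataType -> metastore type string; round-trips to the same literal.
  val asString = HiveMetastoreTypes.toMetastoreType(dt)
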
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
new file mode 100644
index 0000000000..4f33a293c3
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -0,0 +1,966 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+
+import scala.collection.JavaConversions._
+
+import org.apache.hadoop.hive.ql.lib.Node
+import org.apache.hadoop.hive.ql.parse._
+import org.apache.hadoop.hive.ql.plan.PlanUtils
+
+import catalyst.analysis._
+import catalyst.expressions._
+import catalyst.plans._
+import catalyst.plans.logical
+import catalyst.plans.logical._
+import catalyst.types._
+
+/**
+ * Used when we need to start parsing the AST before deciding that we are going to pass the command
+ * back for Hive to execute natively. Will be replaced with a native command that contains the
+ * cmd string.
+ */
+case object NativePlaceholder extends Command
+
+case class DfsCommand(cmd: String) extends Command
+
+case class ShellCommand(cmd: String) extends Command
+
+case class SourceCommand(filePath: String) extends Command
+
+case class AddJar(jarPath: String) extends Command
+
+case class AddFile(filePath: String) extends Command
+
+/** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */
+object HiveQl {
+ protected val nativeCommands = Seq(
+ "TOK_DESCFUNCTION",
+ "TOK_DESCTABLE",
+ "TOK_DESCDATABASE",
+ "TOK_SHOW_TABLESTATUS",
+ "TOK_SHOWDATABASES",
+ "TOK_SHOWFUNCTIONS",
+ "TOK_SHOWINDEXES",
+ "TOK_SHOWINDEXES",
+ "TOK_SHOWPARTITIONS",
+ "TOK_SHOWTABLES",
+
+ "TOK_LOCKTABLE",
+ "TOK_SHOWLOCKS",
+ "TOK_UNLOCKTABLE",
+
+ "TOK_CREATEROLE",
+ "TOK_DROPROLE",
+ "TOK_GRANT",
+ "TOK_GRANT_ROLE",
+ "TOK_REVOKE",
+ "TOK_SHOW_GRANT",
+ "TOK_SHOW_ROLE_GRANT",
+
+ "TOK_CREATEFUNCTION",
+ "TOK_DROPFUNCTION",
+
+ "TOK_ANALYZE",
+ "TOK_ALTERDATABASE_PROPERTIES",
+ "TOK_ALTERINDEX_PROPERTIES",
+ "TOK_ALTERINDEX_REBUILD",
+ "TOK_ALTERTABLE_ADDCOLS",
+ "TOK_ALTERTABLE_ADDPARTS",
+ "TOK_ALTERTABLE_ALTERPARTS",
+ "TOK_ALTERTABLE_ARCHIVE",
+ "TOK_ALTERTABLE_CLUSTER_SORT",
+ "TOK_ALTERTABLE_DROPPARTS",
+ "TOK_ALTERTABLE_PARTITION",
+ "TOK_ALTERTABLE_PROPERTIES",
+ "TOK_ALTERTABLE_RENAME",
+ "TOK_ALTERTABLE_RENAMECOL",
+ "TOK_ALTERTABLE_REPLACECOLS",
+ "TOK_ALTERTABLE_SKEWED",
+ "TOK_ALTERTABLE_TOUCH",
+ "TOK_ALTERTABLE_UNARCHIVE",
+ "TOK_ANALYZE",
+ "TOK_CREATEDATABASE",
+ "TOK_CREATEFUNCTION",
+ "TOK_CREATEINDEX",
+ "TOK_DROPDATABASE",
+ "TOK_DROPINDEX",
+ "TOK_DROPTABLE",
+ "TOK_MSCK",
+
+ // TODO(marmbrus): Figure out how views are expanded by Hive, as we might need to handle this.
+ "TOK_ALTERVIEW_ADDPARTS",
+ "TOK_ALTERVIEW_AS",
+ "TOK_ALTERVIEW_DROPPARTS",
+ "TOK_ALTERVIEW_PROPERTIES",
+ "TOK_ALTERVIEW_RENAME",
+ "TOK_CREATEVIEW",
+ "TOK_DROPVIEW",
+
+ "TOK_EXPORT",
+ "TOK_IMPORT",
+ "TOK_LOAD",
+
+ "TOK_SWITCHDATABASE"
+ )
+
+ /**
+ * A set of implicit transformations that allow Hive ASTNodes to be rewritten by transformations
+ * similar to [[catalyst.trees.TreeNode]].
+ *
+ * Note that this should be considered very experimental and is not intended as a replacement
+ * for TreeNode. Primarily it should be noted that ASTNodes are not immutable and do not appear
+ * to have clean copy semantics. Therefore, users of this class should take care when
+ * copying/modifying trees that might be used elsewhere.
+ */
+ implicit class TransformableNode(n: ASTNode) {
+ /**
+ * Returns a copy of this node where `rule` has been recursively applied to it and all of its
+ * children. When `rule` does not apply to a given node it is left unchanged.
+ * @param rule the function used to transform this node's children
+ */
+ def transform(rule: PartialFunction[ASTNode, ASTNode]): ASTNode = {
+ try {
+ val afterRule = rule.applyOrElse(n, identity[ASTNode])
+ afterRule.withChildren(
+ nilIfEmpty(afterRule.getChildren)
+ .asInstanceOf[Seq[ASTNode]]
+ .map(ast => Option(ast).map(_.transform(rule)).orNull))
+ } catch {
+ case e: Exception =>
+ println(dumpTree(n))
+ throw e
+ }
+ }
+
+ /**
+ * Returns a scala.Seq equivalent to [s], or Nil if [s] is null.
+ */
+ private def nilIfEmpty[A](s: java.util.List[A]): Seq[A] =
+ Option(s).map(_.toSeq).getOrElse(Nil)
+
+ /**
+ * Returns this ASTNode with the text changed to `newText`.
+ */
+ def withText(newText: String): ASTNode = {
+ n.token.asInstanceOf[org.antlr.runtime.CommonToken].setText(newText)
+ n
+ }
+
+ /**
+ * Returns this ASTNode with the children changed to `newChildren`.
+ */
+ def withChildren(newChildren: Seq[ASTNode]): ASTNode = {
+ (1 to n.getChildCount).foreach(_ => n.deleteChild(0))
+ n.addChildren(newChildren)
+ n
+ }
+
+ /**
+ * Throws an error if this is not equal to other.
+ *
+ * Right now this function only checks the name, type, text and children of the node
+ * for equality.
+ */
+ def checkEquals(other: ASTNode) {
+ def check(field: String, f: ASTNode => Any) = if (f(n) != f(other)) {
+ sys.error(s"$field does not match for trees. " +
+ s"'${f(n)}' != '${f(other)}' left: ${dumpTree(n)}, right: ${dumpTree(other)}")
+ }
+ check("name", _.getName)
+ check("type", _.getType)
+ check("text", _.getText)
+ check("numChildren", n => nilIfEmpty(n.getChildren).size)
+
+ val leftChildren = nilIfEmpty(n.getChildren).asInstanceOf[Seq[ASTNode]]
+ val rightChildren = nilIfEmpty(other.getChildren).asInstanceOf[Seq[ASTNode]]
+ leftChildren zip rightChildren foreach {
+ case (l,r) => l checkEquals r
+ }
+ }
+ }
+
+ class ParseException(sql: String, cause: Throwable)
+ extends Exception(s"Failed to parse: $sql", cause)
+
+ /**
+ * Returns the AST for the given SQL string.
+ */
+ def getAst(sql: String): ASTNode = ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql))
+
+ /** Returns a LogicalPlan for a given HiveQL string. */
+ def parseSql(sql: String): LogicalPlan = {
+ try {
+ if (sql.toLowerCase.startsWith("set")) {
+ NativeCommand(sql)
+ } else if (sql.toLowerCase.startsWith("add jar")) {
+ AddJar(sql.drop(8))
+ } else if (sql.toLowerCase.startsWith("add file")) {
+ AddFile(sql.drop(9))
+ } else if (sql.startsWith("dfs")) {
+ DfsCommand(sql)
+ } else if (sql.startsWith("source")) {
+ SourceCommand(sql.split(" ").toSeq match { case Seq("source", filePath) => filePath })
+ } else if (sql.startsWith("!")) {
+ ShellCommand(sql.drop(1))
+ } else {
+ val tree = getAst(sql)
+
+ if (nativeCommands contains tree.getText) {
+ NativeCommand(sql)
+ } else {
+ nodeToPlan(tree) match {
+ case NativePlaceholder => NativeCommand(sql)
+ case other => other
+ }
+ }
+ }
+ } catch {
+ case e: Exception => throw new ParseException(sql, e)
+ }
+ }
+
+ def parseDdl(ddl: String): Seq[Attribute] = {
+ val tree =
+ try {
+ ParseUtils.findRootNonNullToken(
+ (new ParseDriver).parse(ddl, null /* no context required for parsing alone */))
+ } catch {
+ case pe: org.apache.hadoop.hive.ql.parse.ParseException =>
+ throw new RuntimeException(s"Failed to parse ddl: '$ddl'", pe)
+ }
+ assert(tree.asInstanceOf[ASTNode].getText == "TOK_CREATETABLE", "Only CREATE TABLE supported.")
+ val tableOps = tree.getChildren
+ val colList =
+ tableOps
+ .find(_.asInstanceOf[ASTNode].getText == "TOK_TABCOLLIST")
+ .getOrElse(sys.error("No columnList!")).getChildren
+
+ colList.map(nodeToAttribute)
+ }
+
+ /** Extractor for matching Hive's AST Tokens. */
+ object Token {
+ /** @return matches of the form (tokenName, children). */
+ def unapply(t: Any): Option[(String, Seq[ASTNode])] = t match {
+ case t: ASTNode =>
+ Some((t.getText,
+ Option(t.getChildren).map(_.toList).getOrElse(Nil).asInstanceOf[Seq[ASTNode]]))
+ case _ => None
+ }
+ }
+
+ protected def getClauses(clauseNames: Seq[String], nodeList: Seq[ASTNode]): Seq[Option[Node]] = {
+ var remainingNodes = nodeList
+ val clauses = clauseNames.map { clauseName =>
+ val (matches, nonMatches) = remainingNodes.partition(_.getText.toUpperCase == clauseName)
+ remainingNodes = nonMatches ++ (if (matches.nonEmpty) matches.tail else Nil)
+ matches.headOption
+ }
+
+ assert(remainingNodes.isEmpty,
+ s"Unhandled clauses: ${remainingNodes.map(dumpTree(_)).mkString("\n")}")
+ clauses
+ }
+
+ def getClause(clauseName: String, nodeList: Seq[Node]) =
+ getClauseOption(clauseName, nodeList).getOrElse(sys.error(
+ s"Expected clause $clauseName missing from ${nodeList.map(dumpTree(_)).mkString("\n")}"))
+
+ def getClauseOption(clauseName: String, nodeList: Seq[Node]): Option[Node] = {
+ nodeList.filter { case ast: ASTNode => ast.getText == clauseName } match {
+ case Seq(oneMatch) => Some(oneMatch)
+ case Seq() => None
+ case _ => sys.error(s"Found multiple instances of clause $clauseName")
+ }
+ }
+
+ protected def nodeToAttribute(node: Node): Attribute = node match {
+ case Token("TOK_TABCOL", Token(colName, Nil) :: dataType :: Nil) =>
+ AttributeReference(colName, nodeToDataType(dataType), true)()
+
+ case a: ASTNode =>
+ throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ")
+ }
+
+ protected def nodeToDataType(node: Node): DataType = node match {
+ case Token("TOK_BIGINT", Nil) => IntegerType
+ case Token("TOK_INT", Nil) => IntegerType
+ case Token("TOK_TINYINT", Nil) => IntegerType
+ case Token("TOK_SMALLINT", Nil) => IntegerType
+ case Token("TOK_BOOLEAN", Nil) => BooleanType
+ case Token("TOK_STRING", Nil) => StringType
+ case Token("TOK_FLOAT", Nil) => FloatType
+ case Token("TOK_DOUBLE", Nil) => FloatType
+ case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType))
+ case Token("TOK_STRUCT",
+ Token("TOK_TABCOLLIST", fields) :: Nil) =>
+ StructType(fields.map(nodeToStructField))
+ case Token("TOK_MAP",
+ keyType ::
+ valueType :: Nil) =>
+ MapType(nodeToDataType(keyType), nodeToDataType(valueType))
+ case a: ASTNode =>
+ throw new NotImplementedError(s"No parse rules for DataType:\n ${dumpTree(a).toString} ")
+ }
+
+ protected def nodeToStructField(node: Node): StructField = node match {
+ case Token("TOK_TABCOL",
+ Token(fieldName, Nil) ::
+ dataType :: Nil) =>
+ StructField(fieldName, nodeToDataType(dataType), nullable = true)
+ case Token("TOK_TABCOL",
+ Token(fieldName, Nil) ::
+ dataType ::
+ _ /* comment */:: Nil) =>
+ StructField(fieldName, nodeToDataType(dataType), nullable = true)
+ case a: ASTNode =>
+ throw new NotImplementedError(s"No parse rules for StructField:\n ${dumpTree(a).toString} ")
+ }
+
+ protected def nameExpressions(exprs: Seq[Expression]): Seq[NamedExpression] = {
+ exprs.zipWithIndex.map {
+ case (ne: NamedExpression, _) => ne
+ case (e, i) => Alias(e, s"c_$i")()
+ }
+ }
+
+ protected def nodeToPlan(node: Node): LogicalPlan = node match {
+ // Just fake explain for any of the native commands.
+ case Token("TOK_EXPLAIN", explainArgs) if nativeCommands contains explainArgs.head.getText =>
+ NoRelation
+ case Token("TOK_EXPLAIN", explainArgs) =>
+ // Ignore FORMATTED if present.
+ val Some(query) :: _ :: _ :: Nil =
+ getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs)
+ // TODO: support EXTENDED?
+ ExplainCommand(nodeToPlan(query))
+
+ case Token("TOK_CREATETABLE", children)
+ if children.collect { case t@Token("TOK_QUERY", _) => t }.nonEmpty =>
+ // TODO: Parse other clauses.
+ // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL
+ val (
+ Some(tableNameParts) ::
+ _ /* likeTable */ ::
+ Some(query) +:
+ notImplemented) =
+ getClauses(
+ Seq(
+ "TOK_TABNAME",
+ "TOK_LIKETABLE",
+ "TOK_QUERY",
+ "TOK_IFNOTEXISTS",
+ "TOK_TABLECOMMENT",
+ "TOK_TABCOLLIST",
+ "TOK_TABLEPARTCOLS", // Partitioned by
+ "TOK_TABLEBUCKETS", // Clustered by
+ "TOK_TABLESKEWED", // Skewed by
+ "TOK_TABLEROWFORMAT",
+ "TOK_TABLESERIALIZER",
+ "TOK_FILEFORMAT_GENERIC", // For file formats not natively supported by Hive.
+ "TOK_TBLSEQUENCEFILE", // Stored as SequenceFile
+ "TOK_TBLTEXTFILE", // Stored as TextFile
+ "TOK_TBLRCFILE", // Stored as RCFile
+ "TOK_TBLORCFILE", // Stored as ORC File
+ "TOK_TABLEFILEFORMAT", // User-provided InputFormat and OutputFormat
+ "TOK_STORAGEHANDLER", // Storage handler
+ "TOK_TABLELOCATION",
+ "TOK_TABLEPROPERTIES"),
+ children)
+ if (notImplemented.exists(token => !token.isEmpty)) {
+ throw new NotImplementedError(
+ s"Unhandled clauses: ${notImplemented.flatten.map(dumpTree(_)).mkString("\n")}")
+ }
+
+ val (db, tableName) =
+ tableNameParts.getChildren.map{ case Token(part, Nil) => cleanIdentifier(part)} match {
+ case Seq(tableOnly) => (None, tableOnly)
+ case Seq(databaseName, table) => (Some(databaseName), table)
+ }
+ InsertIntoCreatedTable(db, tableName, nodeToPlan(query))
+
+ // If it's not a "CREATE TABLE AS" like above, just pass it back to Hive as a native command.
+ case Token("TOK_CREATETABLE", _) => NativePlaceholder
+
+ case Token("TOK_QUERY",
+ Token("TOK_FROM", fromClause :: Nil) ::
+ insertClauses) =>
+
+ // Return one query for each insert clause.
+ val queries = insertClauses.map { case Token("TOK_INSERT", singleInsert) =>
+ val (
+ intoClause ::
+ destClause ::
+ selectClause ::
+ selectDistinctClause ::
+ whereClause ::
+ groupByClause ::
+ orderByClause ::
+ sortByClause ::
+ clusterByClause ::
+ distributeByClause ::
+ limitClause ::
+ lateralViewClause :: Nil) = {
+ getClauses(
+ Seq(
+ "TOK_INSERT_INTO",
+ "TOK_DESTINATION",
+ "TOK_SELECT",
+ "TOK_SELECTDI",
+ "TOK_WHERE",
+ "TOK_GROUPBY",
+ "TOK_ORDERBY",
+ "TOK_SORTBY",
+ "TOK_CLUSTERBY",
+ "TOK_DISTRIBUTEBY",
+ "TOK_LIMIT",
+ "TOK_LATERAL_VIEW"),
+ singleInsert)
+ }
+
+ val relations = nodeToRelation(fromClause)
+ val withWhere = whereClause.map { whereNode =>
+ val Seq(whereExpr) = whereNode.getChildren.toSeq
+ Filter(nodeToExpr(whereExpr), relations)
+ }.getOrElse(relations)
+
+ val select =
+ (selectClause orElse selectDistinctClause).getOrElse(sys.error("No select clause."))
+
+ // Script transformations are expressed as a select clause with a single expression of type
+ // TOK_TRANSFORM
+ val transformation = select.getChildren.head match {
+ case Token("TOK_SELEXPR",
+ Token("TOK_TRANSFORM",
+ Token("TOK_EXPLIST", inputExprs) ::
+ Token("TOK_SERDE", Nil) ::
+ Token("TOK_RECORDWRITER", writerClause) ::
+ // TODO: Need to support other types of (in/out)put
+ Token(script, Nil) ::
+ Token("TOK_SERDE", serdeClause) ::
+ Token("TOK_RECORDREADER", readerClause) ::
+ outputClause :: Nil) :: Nil) =>
+
+ val output = outputClause match {
+ case Token("TOK_ALIASLIST", aliases) =>
+ aliases.map { case Token(name, Nil) => AttributeReference(name, StringType)() }
+ case Token("TOK_TABCOLLIST", attributes) =>
+ attributes.map { case Token("TOK_TABCOL", Token(name, Nil) :: dataType :: Nil) =>
+ AttributeReference(name, nodeToDataType(dataType))() }
+ }
+ val unescapedScript = BaseSemanticAnalyzer.unescapeSQLString(script)
+
+ Some(
+ logical.ScriptTransformation(
+ inputExprs.map(nodeToExpr),
+ unescapedScript,
+ output,
+ withWhere))
+ case _ => None
+ }
+
+ val withLateralView = lateralViewClause.map { lv =>
+ val Token("TOK_SELECT",
+ Token("TOK_SELEXPR", clauses) :: Nil) = lv.getChildren.head
+
+ val alias =
+ getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText
+
+ Generate(
+ nodesToGenerator(clauses),
+ join = true,
+ outer = false,
+ Some(alias.toLowerCase),
+ withWhere)
+ }.getOrElse(withWhere)
+
+
+ // The projection of the query can either be a normal projection, an aggregation
+ // (if there is a group by) or a script transformation.
+ val withProject = transformation.getOrElse {
+ // Not a transformation so must be either project or aggregation.
+ val selectExpressions = nameExpressions(select.getChildren.flatMap(selExprNodeToExpr))
+
+ groupByClause match {
+ case Some(groupBy) =>
+ Aggregate(groupBy.getChildren.map(nodeToExpr), selectExpressions, withLateralView)
+ case None =>
+ Project(selectExpressions, withLateralView)
+ }
+ }
+
+ val withDistinct =
+ if (selectDistinctClause.isDefined) Distinct(withProject) else withProject
+
+ val withSort =
+ (orderByClause, sortByClause, distributeByClause, clusterByClause) match {
+ case (Some(totalOrdering), None, None, None) =>
+ Sort(totalOrdering.getChildren.map(nodeToSortOrder), withDistinct)
+ case (None, Some(perPartitionOrdering), None, None) =>
+ SortPartitions(perPartitionOrdering.getChildren.map(nodeToSortOrder), withDistinct)
+ case (None, None, Some(partitionExprs), None) =>
+ Repartition(partitionExprs.getChildren.map(nodeToExpr), withDistinct)
+ case (None, Some(perPartitionOrdering), Some(partitionExprs), None) =>
+ SortPartitions(perPartitionOrdering.getChildren.map(nodeToSortOrder),
+ Repartition(partitionExprs.getChildren.map(nodeToExpr), withDistinct))
+ case (None, None, None, Some(clusterExprs)) =>
+ SortPartitions(clusterExprs.getChildren.map(nodeToExpr).map(SortOrder(_, Ascending)),
+ Repartition(clusterExprs.getChildren.map(nodeToExpr), withDistinct))
+ case (None, None, None, None) => withDistinct
+ case _ => sys.error("Unsupported set of ordering / distribution clauses.")
+ }
+
+ val withLimit =
+ limitClause.map(l => nodeToExpr(l.getChildren.head))
+ .map(StopAfter(_, withSort))
+ .getOrElse(withSort)
+
+ // TOK_INSERT_INTO means to add files to the table.
+ // TOK_DESTINATION means to overwrite the table.
+ val resultDestination =
+ (intoClause orElse destClause).getOrElse(sys.error("No destination found."))
+ val overwrite = intoClause.isEmpty
+ nodeToDest(
+ resultDestination,
+ withLimit,
+ overwrite)
+ }
+
+ // If there are multiple INSERTs, just UNION them together into one query.
+ queries.reduceLeft(Union)
+
+ case Token("TOK_UNION", left :: right :: Nil) => Union(nodeToPlan(left), nodeToPlan(right))
+
+ case a: ASTNode =>
+ throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ")
+ }
+
+ val allJoinTokens = "(TOK_.*JOIN)".r
+ val laterViewToken = "TOK_LATERAL_VIEW(.*)".r
+ def nodeToRelation(node: Node): LogicalPlan = node match {
+ case Token("TOK_SUBQUERY",
+ query :: Token(alias, Nil) :: Nil) =>
+ Subquery(alias, nodeToPlan(query))
+
+ case Token(laterViewToken(isOuter), selectClause :: relationClause :: Nil) =>
+ val Token("TOK_SELECT",
+ Token("TOK_SELEXPR", clauses) :: Nil) = selectClause
+
+ val alias = getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText
+
+ Generate(
+ nodesToGenerator(clauses),
+ join = true,
+ outer = isOuter.nonEmpty,
+ Some(alias.toLowerCase),
+ nodeToRelation(relationClause))
+
+ /* All relations, possibly with aliases or sampling clauses. */
+ case Token("TOK_TABREF", clauses) =>
+ // If the last clause is not a token then it's the alias of the table.
+ val (nonAliasClauses, aliasClause) =
+ if (clauses.last.getText.startsWith("TOK")) {
+ (clauses, None)
+ } else {
+ (clauses.dropRight(1), Some(clauses.last))
+ }
+
+ val (Some(tableNameParts) ::
+ splitSampleClause ::
+ bucketSampleClause :: Nil) = {
+ getClauses(Seq("TOK_TABNAME", "TOK_TABLESPLITSAMPLE", "TOK_TABLEBUCKETSAMPLE"),
+ nonAliasClauses)
+ }
+
+ val (db, tableName) =
+ tableNameParts.getChildren.map{ case Token(part, Nil) => cleanIdentifier(part)} match {
+ case Seq(tableOnly) => (None, tableOnly)
+ case Seq(databaseName, table) => (Some(databaseName), table)
+ }
+ val alias = aliasClause.map { case Token(a, Nil) => cleanIdentifier(a) }
+ val relation = UnresolvedRelation(db, tableName, alias)
+
+ // Apply sampling if requested.
+ (bucketSampleClause orElse splitSampleClause).map {
+ case Token("TOK_TABLESPLITSAMPLE",
+ Token("TOK_ROWCOUNT", Nil) ::
+ Token(count, Nil) :: Nil) =>
+ StopAfter(Literal(count.toInt), relation)
+ case Token("TOK_TABLESPLITSAMPLE",
+ Token("TOK_PERCENT", Nil) ::
+ Token(fraction, Nil) :: Nil) =>
+ Sample(fraction.toDouble, withReplacement = false, (math.random * 1000).toInt, relation)
+ case Token("TOK_TABLEBUCKETSAMPLE",
+ Token(numerator, Nil) ::
+ Token(denominator, Nil) :: Nil) =>
+ val fraction = numerator.toDouble / denominator.toDouble
+ Sample(fraction, withReplacement = false, (math.random * 1000).toInt, relation)
+ case a: ASTNode =>
+ throw new NotImplementedError(
+ s"""No parse rules for sampling clause: ${a.getType}, text: ${a.getText} :
+ |${dumpTree(a).toString}
+ """.stripMargin)
+ }.getOrElse(relation)
+
+ case Token("TOK_UNIQUEJOIN", joinArgs) =>
+ val tableOrdinals =
+ joinArgs.zipWithIndex.filter {
+ case (arg, i) => arg.getText == "TOK_TABREF"
+ }.map(_._2)
+
+ val isPreserved = tableOrdinals.map(i => (i - 1 < 0) || joinArgs(i - 1).getText == "PRESERVE")
+ val tables = tableOrdinals.map(i => nodeToRelation(joinArgs(i)))
+ val joinExpressions = tableOrdinals.map(i => joinArgs(i + 1).getChildren.map(nodeToExpr))
+
+ val joinConditions = joinExpressions.sliding(2).map {
+ case Seq(c1, c2) =>
+ val predicates = (c1, c2).zipped.map { case (e1, e2) => Equals(e1, e2): Expression }
+ predicates.reduceLeft(And)
+ }.toBuffer
+
+ val joinType = isPreserved.sliding(2).map {
+ case Seq(true, true) => FullOuter
+ case Seq(true, false) => LeftOuter
+ case Seq(false, true) => RightOuter
+ case Seq(false, false) => Inner
+ }.toBuffer
+
+ val joinedTables = tables.reduceLeft(Join(_,_, Inner, None))
+
+ // Must be a top-down transform.
+ val joinedResult = joinedTables transform {
+ case j: Join =>
+ j.copy(
+ condition = Some(joinConditions.remove(joinConditions.length - 1)),
+ joinType = joinType.remove(joinType.length - 1))
+ }
+
+ val groups = (0 until joinExpressions.head.size).map(i => Coalesce(joinExpressions.map(_(i))))
+
+ // Unique join is not really the same as an outer join, so we must group together results where
+ // the joinExpressions are the same. Taking the First of each value is only okay because the
+ // user of a unique join is implicitly promising that there is only one result.
+ // TODO: This doesn't actually work since [[Star]] is not a valid aggregate expression.
+ // Instead we should figure out how important supporting this feature is and whether it is
+ // worth the number of hacks that will be required to implement it. Namely, we need to add
+ // some sort of mapped star expansion that would expand each child output row to be similarly
+ // named output expressions where some aggregate expression has been applied (i.e. First).
+ ??? /// Aggregate(groups, Star(None, First(_)) :: Nil, joinedResult)
+
+ case Token(allJoinTokens(joinToken),
+ relation1 ::
+ relation2 :: other) =>
+ assert(other.size <= 1, s"Unhandled join child ${other}")
+ val joinType = joinToken match {
+ case "TOK_JOIN" => Inner
+ case "TOK_RIGHTOUTERJOIN" => RightOuter
+ case "TOK_LEFTOUTERJOIN" => LeftOuter
+ case "TOK_FULLOUTERJOIN" => FullOuter
+ }
+ assert(other.size <= 1, "Unhandled join clauses.")
+ Join(nodeToRelation(relation1),
+ nodeToRelation(relation2),
+ joinType,
+ other.headOption.map(nodeToExpr))
+
+ case a: ASTNode =>
+ throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ")
+ }
+
+ def nodeToSortOrder(node: Node): SortOrder = node match {
+ case Token("TOK_TABSORTCOLNAMEASC", sortExpr :: Nil) =>
+ SortOrder(nodeToExpr(sortExpr), Ascending)
+ case Token("TOK_TABSORTCOLNAMEDESC", sortExpr :: Nil) =>
+ SortOrder(nodeToExpr(sortExpr), Descending)
+
+ case a: ASTNode =>
+ throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ")
+ }
+
+ val destinationToken = "TOK_DESTINATION|TOK_INSERT_INTO".r
+ protected def nodeToDest(
+ node: Node,
+ query: LogicalPlan,
+ overwrite: Boolean): LogicalPlan = node match {
+ case Token(destinationToken(),
+ Token("TOK_DIR",
+ Token("TOK_TMP_FILE", Nil) :: Nil) :: Nil) =>
+ query
+
+ case Token(destinationToken(),
+ Token("TOK_TAB",
+ tableArgs) :: Nil) =>
+ val Some(tableNameParts) :: partitionClause :: Nil =
+ getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs)
+
+ val (db, tableName) =
+ tableNameParts.getChildren.map{ case Token(part, Nil) => cleanIdentifier(part)} match {
+ case Seq(tableOnly) => (None, tableOnly)
+ case Seq(databaseName, table) => (Some(databaseName), table)
+ }
+
+ val partitionKeys = partitionClause.map(_.getChildren.map {
+ // Parse partitions. We also make keys case insensitive.
+ case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) =>
+ cleanIdentifier(key.toLowerCase) -> Some(PlanUtils.stripQuotes(value))
+ case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) =>
+ cleanIdentifier(key.toLowerCase) -> None
+ }.toMap).getOrElse(Map.empty)
+
+ if (partitionKeys.values.exists(p => p.isEmpty)) {
+ throw new NotImplementedError(s"Do not support INSERT INTO/OVERWRITE with" +
+ s"dynamic partitioning.")
+ }
+
+ InsertIntoTable(UnresolvedRelation(db, tableName, None), partitionKeys, query, overwrite)
+
+ case a: ASTNode =>
+ throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ")
+ }
+
+ protected def selExprNodeToExpr(node: Node): Option[Expression] = node match {
+ case Token("TOK_SELEXPR",
+ e :: Nil) =>
+ Some(nodeToExpr(e))
+
+ case Token("TOK_SELEXPR",
+ e :: Token(alias, Nil) :: Nil) =>
+ Some(Alias(nodeToExpr(e), alias)())
+
+ /* Hints are ignored */
+ case Token("TOK_HINTLIST", _) => None
+
+ case a: ASTNode =>
+ throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ")
+ }
+
+
+ protected val escapedIdentifier = "`([^`]+)`".r
+ /** Strips backticks from ident if present */
+ protected def cleanIdentifier(ident: String): String = ident match {
+ case escapedIdentifier(i) => i
+ case plainIdent => plainIdent
+ }
+
+ val numericAstTypes = Seq(
+ HiveParser.Number,
+ HiveParser.TinyintLiteral,
+ HiveParser.SmallintLiteral,
+ HiveParser.BigintLiteral)
+
+ /* Case insensitive matches */
+ val COUNT = "(?i)COUNT".r
+ val AVG = "(?i)AVG".r
+ val SUM = "(?i)SUM".r
+ val RAND = "(?i)RAND".r
+ val AND = "(?i)AND".r
+ val OR = "(?i)OR".r
+ val NOT = "(?i)NOT".r
+ val TRUE = "(?i)TRUE".r
+ val FALSE = "(?i)FALSE".r
+
+ protected def nodeToExpr(node: Node): Expression = node match {
+ /* Attribute References */
+ case Token("TOK_TABLE_OR_COL",
+ Token(name, Nil) :: Nil) =>
+ UnresolvedAttribute(cleanIdentifier(name))
+ case Token(".", qualifier :: Token(attr, Nil) :: Nil) =>
+ nodeToExpr(qualifier) match {
+ case UnresolvedAttribute(qualifierName) =>
+ UnresolvedAttribute(qualifierName + "." + cleanIdentifier(attr))
+ // The precedence for . seems to be wrong, so [] binds tighter and we need to go inside to
+ // find the underlying attribute references.
+ case GetItem(UnresolvedAttribute(qualifierName), ordinal) =>
+ GetItem(UnresolvedAttribute(qualifierName + "." + cleanIdentifier(attr)), ordinal)
+ }
+
+ /* Stars (*) */
+ case Token("TOK_ALLCOLREF", Nil) => Star(None)
+ // The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only
+ // have a single child, which is tableName.
+ case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", Token(name, Nil) :: Nil) :: Nil) =>
+ Star(Some(name))
+
+ /* Aggregate Functions */
+ case Token("TOK_FUNCTION", Token(AVG(), Nil) :: arg :: Nil) => Average(nodeToExpr(arg))
+ case Token("TOK_FUNCTION", Token(COUNT(), Nil) :: arg :: Nil) => Count(nodeToExpr(arg))
+ case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => Count(Literal(1))
+ case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => CountDistinct(args.map(nodeToExpr))
+ case Token("TOK_FUNCTION", Token(SUM(), Nil) :: arg :: Nil) => Sum(nodeToExpr(arg))
+ case Token("TOK_FUNCTIONDI", Token(SUM(), Nil) :: arg :: Nil) => SumDistinct(nodeToExpr(arg))
+
+ /* Casts */
+ case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), StringType)
+ case Token("TOK_FUNCTION", Token("TOK_VARCHAR", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), StringType)
+ case Token("TOK_FUNCTION", Token("TOK_INT", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), IntegerType)
+ case Token("TOK_FUNCTION", Token("TOK_BIGINT", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), LongType)
+ case Token("TOK_FUNCTION", Token("TOK_FLOAT", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), FloatType)
+ case Token("TOK_FUNCTION", Token("TOK_DOUBLE", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), DoubleType)
+ case Token("TOK_FUNCTION", Token("TOK_SMALLINT", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), ShortType)
+ case Token("TOK_FUNCTION", Token("TOK_TINYINT", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), ByteType)
+ case Token("TOK_FUNCTION", Token("TOK_BINARY", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), BinaryType)
+ case Token("TOK_FUNCTION", Token("TOK_BOOLEAN", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), BooleanType)
+ case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), DecimalType)
+
+ /* Arithmetic */
+ case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child))
+ case Token("+", left :: right:: Nil) => Add(nodeToExpr(left), nodeToExpr(right))
+ case Token("-", left :: right:: Nil) => Subtract(nodeToExpr(left), nodeToExpr(right))
+ case Token("*", left :: right:: Nil) => Multiply(nodeToExpr(left), nodeToExpr(right))
+ case Token("/", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right))
+ case Token("DIV", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right))
+ case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right))
+
+ /* Comparisons */
+ case Token("=", left :: right:: Nil) => Equals(nodeToExpr(left), nodeToExpr(right))
+ case Token("!=", left :: right:: Nil) => Not(Equals(nodeToExpr(left), nodeToExpr(right)))
+ case Token("<>", left :: right:: Nil) => Not(Equals(nodeToExpr(left), nodeToExpr(right)))
+ case Token(">", left :: right:: Nil) => GreaterThan(nodeToExpr(left), nodeToExpr(right))
+ case Token(">=", left :: right:: Nil) => GreaterThanOrEqual(nodeToExpr(left), nodeToExpr(right))
+ case Token("<", left :: right:: Nil) => LessThan(nodeToExpr(left), nodeToExpr(right))
+ case Token("<=", left :: right:: Nil) => LessThanOrEqual(nodeToExpr(left), nodeToExpr(right))
+ case Token("LIKE", left :: right:: Nil) =>
+ UnresolvedFunction("LIKE", Seq(nodeToExpr(left), nodeToExpr(right)))
+ case Token("RLIKE", left :: right:: Nil) =>
+ UnresolvedFunction("RLIKE", Seq(nodeToExpr(left), nodeToExpr(right)))
+ case Token("REGEXP", left :: right:: Nil) =>
+ UnresolvedFunction("REGEXP", Seq(nodeToExpr(left), nodeToExpr(right)))
+ case Token("TOK_FUNCTION", Token("TOK_ISNOTNULL", Nil) :: child :: Nil) =>
+ IsNotNull(nodeToExpr(child))
+ case Token("TOK_FUNCTION", Token("TOK_ISNULL", Nil) :: child :: Nil) =>
+ IsNull(nodeToExpr(child))
+ case Token("TOK_FUNCTION", Token("IN", Nil) :: value :: list) =>
+ In(nodeToExpr(value), list.map(nodeToExpr))
+
+ /* Boolean Logic */
+ case Token(AND(), left :: right:: Nil) => And(nodeToExpr(left), nodeToExpr(right))
+ case Token(OR(), left :: right:: Nil) => Or(nodeToExpr(left), nodeToExpr(right))
+ case Token(NOT(), child :: Nil) => Not(nodeToExpr(child))
+
+ /* Complex datatype manipulation */
+ case Token("[", child :: ordinal :: Nil) =>
+ GetItem(nodeToExpr(child), nodeToExpr(ordinal))
+
+ /* Other functions */
+ case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand
+
+ /* UDFs - Must be last, otherwise they will preempt built-in functions */
+ case Token("TOK_FUNCTION", Token(name, Nil) :: args) =>
+ UnresolvedFunction(name, args.map(nodeToExpr))
+ case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) =>
+ UnresolvedFunction(name, Star(None) :: Nil)
+
+ /* Literals */
+ case Token("TOK_NULL", Nil) => Literal(null, NullType)
+ case Token(TRUE(), Nil) => Literal(true, BooleanType)
+ case Token(FALSE(), Nil) => Literal(false, BooleanType)
+ case Token("TOK_STRINGLITERALSEQUENCE", strings) =>
+ Literal(strings.map(s => BaseSemanticAnalyzer.unescapeSQLString(s.getText)).mkString)
+
+ // This code is adapted from
+ // /ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java#L223
+ case ast: ASTNode if numericAstTypes contains ast.getType =>
+ var v: Literal = null
+ try {
+ if (ast.getText.endsWith("L")) {
+ // Literal bigint.
+ v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toLong, LongType)
+ } else if (ast.getText.endsWith("S")) {
+ // Literal smallint.
+ v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toShort, ShortType)
+ } else if (ast.getText.endsWith("Y")) {
+ // Literal tinyint.
+ v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toByte, ByteType)
+ } else if (ast.getText.endsWith("BD")) {
+ // Literal decimal
+ val strVal = ast.getText.substring(0, ast.getText.length() - 2)
+ v = Literal(BigDecimal(strVal), DecimalType)
+ } else {
+ // Try the widest type first; each narrower parse overwrites `v` if it succeeds, and if it
+ // throws NumberFormatException the catch below keeps the last literal that did parse.
+ v = Literal(ast.getText.toDouble, DoubleType)
+ v = Literal(ast.getText.toLong, LongType)
+ v = Literal(ast.getText.toInt, IntegerType)
+ }
+ } catch {
+ case nfe: NumberFormatException => // Do nothing
+ }
+
+ if (v == null) {
+ sys.error(s"Failed to parse number ${ast.getText}")
+ } else {
+ v
+ }
+
+ case ast: ASTNode if ast.getType == HiveParser.StringLiteral =>
+ Literal(BaseSemanticAnalyzer.unescapeSQLString(ast.getText))
+
+ case a: ASTNode =>
+ throw new NotImplementedError(
+ s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText} :
+ |${dumpTree(a).toString}
+ """.stripMargin)
+ }
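Editorial sketch (not part of the patch): the numeric-literal branch above keys off Hive's type suffixes (L, S, Y, BD) and otherwise tries Double/Long/Int in turn. The mapping can be summarized in isolation; the unsuffixed fallback below is a simplification of that try-and-narrow logic, and the object name is hypothetical:

object NumericSuffixDemo extends App {
  def literalType(text: String): String = text match {
    case t if t.endsWith("BD") => "DecimalType" // e.g. 3.14BD
    case t if t.endsWith("L")  => "LongType"    // e.g. 100L
    case t if t.endsWith("S")  => "ShortType"   // e.g. 100S
    case t if t.endsWith("Y")  => "ByteType"    // e.g. 100Y
    case t if t.contains(".")  => "DoubleType"  // unsuffixed, has a decimal point
    case _                     => "IntegerType (or LongType if it does not fit)"
  }

  Seq("100", "100L", "100S", "100Y", "3.14", "3.14BD").foreach { t =>
    println(s"$t -> ${literalType(t)}")
  }
}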
+
+
+ val explode = "(?i)explode".r
+ def nodesToGenerator(nodes: Seq[Node]): Generator = {
+ val function = nodes.head
+
+ val attributes = nodes.flatMap {
+ case Token(a, Nil) => a.toLowerCase :: Nil
+ case _ => Nil
+ }
+
+ function match {
+ case Token("TOK_FUNCTION", Token(explode(), Nil) :: child :: Nil) =>
+ Explode(attributes, nodeToExpr(child))
+
+ case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) =>
+ HiveGenericUdtf(functionName, attributes, children.map(nodeToExpr))
+
+ case a: ASTNode =>
+ throw new NotImplementedError(
+ s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText}, tree:
+ |${dumpTree(a).toString}
+ """.stripMargin)
+ }
+ }
+
+ def dumpTree(node: Node, builder: StringBuilder = new StringBuilder, indent: Int = 0)
+ : StringBuilder = {
+ node match {
+ case a: ASTNode => builder.append((" " * indent) + a.getText + "\n")
+ case other => sys.error(s"Non ASTNode encountered: $other")
+ }
+
+ Option(node.getChildren).map(_.toList).getOrElse(Nil).foreach(dumpTree(_, builder, indent + 1))
+ builder
+ }
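Editorial sketch (not part of the patch): dumpTree is a plain pre-order walk that indents one space per level. The same shape on a trivial stand-in tree type (SimpleNode is hypothetical, chosen only to avoid depending on ASTNode):

object DumpTreeDemo extends App {
  case class SimpleNode(text: String, children: Seq[SimpleNode] = Nil)

  def dump(node: SimpleNode, builder: StringBuilder = new StringBuilder, indent: Int = 0): StringBuilder = {
    builder.append((" " * indent) + node.text + "\n")
    node.children.foreach(dump(_, builder, indent + 1))
    builder
  }

  val tree = SimpleNode("TOK_QUERY", Seq(
    SimpleNode("TOK_FROM", Seq(SimpleNode("src"))),
    SimpleNode("TOK_INSERT")))
  print(dump(tree))
  // TOK_QUERY
  //  TOK_FROM
  //   src
  //  TOK_INSERT
}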
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
new file mode 100644
index 0000000000..92d84208ab
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+
+import catalyst.expressions._
+import catalyst.planning._
+import catalyst.plans._
+import catalyst.plans.logical.{BaseRelation, LogicalPlan}
+
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.parquet.{ParquetRelation, InsertIntoParquetTable, ParquetTableScan}
+
+trait HiveStrategies {
+ // Possibly being too clever with types here... or not clever enough.
+ self: SQLContext#SparkPlanner =>
+
+ val hiveContext: HiveContext
+
+ object Scripts extends Strategy {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case logical.ScriptTransformation(input, script, output, child) =>
+ ScriptTransformation(input, script, output, planLater(child))(hiveContext) :: Nil
+ case _ => Nil
+ }
+ }
+
+ object DataSinks extends Strategy {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case logical.InsertIntoTable(table: MetastoreRelation, partition, child, overwrite) =>
+ InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil
+ case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) =>
+ InsertIntoParquetTable(table, planLater(child))(hiveContext.sparkContext) :: Nil
+ case _ => Nil
+ }
+ }
+
+ object HiveTableScans extends Strategy {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ // Push attributes into table scan when possible.
+ case p @ logical.Project(projectList, m: MetastoreRelation) if isSimpleProject(projectList) =>
+ HiveTableScan(projectList.asInstanceOf[Seq[Attribute]], m, None)(hiveContext) :: Nil
+ case m: MetastoreRelation =>
+ HiveTableScan(m.output, m, None)(hiveContext) :: Nil
+ case _ => Nil
+ }
+ }
+
+ /**
+ * A strategy used to detect filtering predicates on top of a partitioned relation to help
+ * partition pruning.
+ *
+ * This strategy itself doesn't perform partition pruning; it just collects and combines all the
+ * partition pruning predicates and passes them down to the underlying [[HiveTableScan]] operator,
+ * which does the actual pruning work.
+ */
+ object PartitionPrunings extends Strategy {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case p @ FilteredOperation(predicates, relation: MetastoreRelation)
+ if relation.isPartitioned =>
+
+ val partitionKeyIds = relation.partitionKeys.map(_.id).toSet
+
+ // Filter out all predicates that only deal with partition keys
+ val (pruningPredicates, otherPredicates) = predicates.partition {
+ _.references.map(_.id).subsetOf(partitionKeyIds)
+ }
+
+ val scan = HiveTableScan(
+ relation.output, relation, pruningPredicates.reduceLeftOption(And))(hiveContext)
+
+ otherPredicates
+ .reduceLeftOption(And)
+ .map(Filter(_, scan))
+ .getOrElse(scan) :: Nil
+
+ case _ =>
+ Nil
+ }
+ }
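Editorial sketch (not part of the patch): the core of PartitionPrunings is a partition of the predicate list by whether each predicate references only partition keys. The same split on plain strings and column-name sets, with hypothetical names and data:

object PruningPredicateDemo extends App {
  val partitionKeys = Set("ds", "hr")
  // (predicate text, columns it references)
  val predicates: Seq[(String, Set[String])] = Seq(
    ("ds = '2008-04-08'", Set("ds")),
    ("hr > 10", Set("hr")),
    ("key < 100", Set("key")))

  val (pruningPredicates, otherPredicates) =
    predicates.partition { case (_, refs) => refs.subsetOf(partitionKeys) }

  println(pruningPredicates.map(_._1)) // List(ds = '2008-04-08', hr > 10) -> pushed into HiveTableScan
  println(otherPredicates.map(_._1))   // List(key < 100)                  -> kept in a Filter above the scan
}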
+
+ /**
+ * A strategy that detects projects and filters over some relation and applies column pruning if
+ * possible. Partition pruning is applied first if the relation is partitioned.
+ */
+ object ColumnPrunings extends Strategy {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ // TODO(andre): the current mix of HiveRelation and ParquetRelation
+ // here appears artificial; try to refactor to break it into two
+ case PhysicalOperation(projectList, predicates, relation: BaseRelation) =>
+ val predicateOpt = predicates.reduceOption(And)
+ val predicateRefs = predicateOpt.map(_.references).getOrElse(Set.empty)
+ val projectRefs = projectList.flatMap(_.references)
+
+ // To figure out what columns to preserve after column pruning, we need to consider:
+ //
+ // 1. Columns referenced by the project list (order preserved)
+ // 2. Columns referenced by filtering predicates but not by project list
+ // 3. Relation output
+ //
+ // Then the final result is ((1 union 2) intersect 3)
+ val prunedCols = (projectRefs ++ (predicateRefs -- projectRefs)).intersect(relation.output)
+
+ val filteredScans =
+ if (relation.isPartitioned) { // from here on relation must be a [[MetastoreRelation]]
+ // Applies partition pruning first for partitioned table
+ val filteredRelation = predicateOpt.map(logical.Filter(_, relation)).getOrElse(relation)
+ PartitionPrunings(filteredRelation).view.map(_.transform {
+ case scan: HiveTableScan =>
+ scan.copy(attributes = prunedCols)(hiveContext)
+ })
+ } else {
+ val scan = relation match {
+ case MetastoreRelation(_, _, _) => {
+ HiveTableScan(
+ prunedCols,
+ relation.asInstanceOf[MetastoreRelation],
+ None)(hiveContext)
+ }
+ case ParquetRelation(_, _) => {
+ ParquetTableScan(
+ relation.output,
+ relation.asInstanceOf[ParquetRelation],
+ None)(hiveContext.sparkContext)
+ .pruneColumns(prunedCols)
+ }
+ }
+ predicateOpt.map(execution.Filter(_, scan)).getOrElse(scan) :: Nil
+ }
+
+ if (isSimpleProject(projectList) && prunedCols == projectRefs) {
+ filteredScans
+ } else {
+ filteredScans.view.map(execution.Project(projectList, _))
+ }
+
+ case _ =>
+ Nil
+ }
+ }
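Editorial sketch (not part of the patch): the ((1 union 2) intersect 3) rule described in ColumnPrunings, worked on plain column names instead of attribute references (names and values are illustrative only):

object ColumnPruningDemo extends App {
  val projectRefs    = Seq("key")                       // 1. columns in the project list (order kept)
  val predicateRefs  = Set("value", "key")              // 2. columns referenced by the predicates
  val relationOutput = Seq("key", "value", "ds", "hr")  // 3. relation output

  val prunedCols = (projectRefs ++ (predicateRefs -- projectRefs)).intersect(relationOutput)

  println(prunedCols) // List(key, value): project-list columns first, then extra predicate columns
}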
+
+ /**
+ * Returns true if `projectList` only performs column pruning and does not evaluate other
+ * complex expressions.
+ */
+ def isSimpleProject(projectList: Seq[NamedExpression]) = {
+ projectList.forall(_.isInstanceOf[Attribute])
+ }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ScriptTransformation.scala
new file mode 100644
index 0000000000..f20e9d4de4
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ScriptTransformation.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+
+import java.io.{InputStreamReader, BufferedReader}
+
+import catalyst.expressions._
+import org.apache.spark.sql.execution._
+
+import scala.collection.JavaConversions._
+
+/**
+ * Transforms the input by forking and running the specified script.
+ *
+ * @param input the set of expressions that should be passed to the script.
+ * @param script the command that should be executed.
+ * @param output the attributes that are produced by the script.
+ */
+case class ScriptTransformation(
+ input: Seq[Expression],
+ script: String,
+ output: Seq[Attribute],
+ child: SparkPlan)(@transient sc: HiveContext)
+ extends UnaryNode {
+
+ override def otherCopyArgs = sc :: Nil
+
+ def execute() = {
+ child.execute().mapPartitions { iter =>
+ val cmd = List("/bin/bash", "-c", script)
+ val builder = new ProcessBuilder(cmd)
+ val proc = builder.start()
+ val inputStream = proc.getInputStream
+ val outputStream = proc.getOutputStream
+ val reader = new BufferedReader(new InputStreamReader(inputStream))
+
+ // TODO: This should be exposed as an iterator instead of reading in all the data at once.
+ val outputLines = collection.mutable.ArrayBuffer[Row]()
+ val readerThread = new Thread("Transform OutputReader") {
+ override def run() {
+ var curLine = reader.readLine()
+ while (curLine != null) {
+ // TODO: Use SerDe
+ outputLines += new GenericRow(curLine.split("\t").asInstanceOf[Array[Any]])
+ curLine = reader.readLine()
+ }
+ }
+ }
+ readerThread.start()
+ val outputProjection = new Projection(input)
+ iter
+ .map(outputProjection)
+ // TODO: Use SerDe
+ .map(_.mkString("", "\t", "\n").getBytes).foreach(outputStream.write)
+ outputStream.close()
+ readerThread.join()
+ outputLines.toIterator
+ }
+ }
+}
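Editorial sketch (not part of the patch): the operator above boils down to forking a process, feeding it tab-separated rows, and collecting its stdout on a separate thread so neither side blocks. A self-contained version of that pattern, assuming a Unix-like machine where /bin/cat exists (so the output simply echoes the input); all names are hypothetical:

import java.io.{BufferedReader, InputStreamReader}

object PipeThroughScriptDemo extends App {
  val proc = new ProcessBuilder("/bin/cat").start()
  val reader = new BufferedReader(new InputStreamReader(proc.getInputStream))

  val outputLines = collection.mutable.ArrayBuffer[String]()
  val readerThread = new Thread("demo-output-reader") {
    override def run(): Unit = {
      var line = reader.readLine()
      while (line != null) {
        outputLines += line
        line = reader.readLine()
      }
    }
  }
  readerThread.start()

  val out = proc.getOutputStream
  Seq("1\ta", "2\tb").foreach(row => out.write((row + "\n").getBytes))
  out.close()        // EOF tells the child to finish
  readerThread.join()

  println(outputLines) // the two input rows, echoed back
}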
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
new file mode 100644
index 0000000000..71d751cbc4
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants._
+import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable}
+import org.apache.hadoop.hive.ql.plan.TableDesc
+import org.apache.hadoop.hive.serde2.Deserializer
+import org.apache.hadoop.hive.ql.exec.Utilities
+import org.apache.hadoop.io.Writable
+import org.apache.hadoop.fs.{Path, PathFilter}
+import org.apache.hadoop.mapred.{FileInputFormat, JobConf, InputFormat}
+
+import org.apache.spark.SerializableWritable
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.{HadoopRDD, UnionRDD, EmptyRDD, RDD}
+
+
+/**
+ * A trait for subclasses that handle table scans.
+ */
+private[hive] sealed trait TableReader {
+ def makeRDDForTable(hiveTable: HiveTable): RDD[_]
+
+ def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[_]
+
+}
+
+
+/**
+ * Helper class for scanning tables stored in Hadoop - e.g., to read Hive tables that reside in the
+ * data warehouse directory.
+ */
+private[hive]
+class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveContext)
+ extends TableReader {
+
+ // Choose the minimum number of splits. If mapred.map.tasks is set, then use that unless
+ // it is smaller than what Spark suggests.
+ private val _minSplitsPerRDD = math.max(
+ sc.hiveconf.getInt("mapred.map.tasks", 1), sc.sparkContext.defaultMinSplits)
+
+
+ // TODO: set aws s3 credentials.
+
+ private val _broadcastedHiveConf =
+ sc.sparkContext.broadcast(new SerializableWritable(sc.hiveconf))
+
+ def broadcastedHiveConf = _broadcastedHiveConf
+
+ def hiveConf = _broadcastedHiveConf.value.value
+
+ override def makeRDDForTable(hiveTable: HiveTable): RDD[_] =
+ makeRDDForTable(
+ hiveTable,
+ _tableDesc.getDeserializerClass.asInstanceOf[Class[Deserializer]],
+ filterOpt = None)
+
+ /**
+ * Creates a Hadoop RDD to read data from the target table's data directory. Returns a transformed
+ * RDD that contains deserialized rows.
+ *
+ * @param hiveTable Hive metadata for the table being scanned.
+ * @param deserializerClass Class of the SerDe used to deserialize Writables read from Hadoop.
+ * @param filterOpt If defined, then the filter is used to reject files contained in the data
+ * directory being read. If None, then all files are accepted.
+ */
+ def makeRDDForTable(
+ hiveTable: HiveTable,
+ deserializerClass: Class[_ <: Deserializer],
+ filterOpt: Option[PathFilter]): RDD[_] =
+ {
+ assert(!hiveTable.isPartitioned, """makeRDDForTable() cannot be called on a partitioned table,
+ since input formats may differ across partitions. Use makeRDDForPartitionedTable() instead.""")
+
+ // Create local references to member variables, so that the entire `this` object won't be
+ // serialized in the closure below.
+ val tableDesc = _tableDesc
+ val broadcastedHiveConf = _broadcastedHiveConf
+
+ val tablePath = hiveTable.getPath
+ val inputPathStr = applyFilterIfNeeded(tablePath, filterOpt)
+
+ //logDebug("Table input: %s".format(tablePath))
+ val ifc = hiveTable.getInputFormatClass
+ .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]]
+ val hadoopRDD = createHadoopRdd(tableDesc, inputPathStr, ifc)
+
+ val deserializedHadoopRDD = hadoopRDD.mapPartitions { iter =>
+ val hconf = broadcastedHiveConf.value.value
+ val deserializer = deserializerClass.newInstance()
+ deserializer.initialize(hconf, tableDesc.getProperties)
+
+ // Deserialize each Writable to get the row value.
+ iter.map {
+ case v: Writable => deserializer.deserialize(v)
+ case value =>
+ sys.error(s"Unable to deserialize non-Writable: $value of ${value.getClass.getName}")
+ }
+ }
+ deserializedHadoopRDD
+ }
+
+ override def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[_] = {
+ val partitionToDeserializer = partitions.map(part =>
+ (part, part.getDeserializer.getClass.asInstanceOf[Class[Deserializer]])).toMap
+ makeRDDForPartitionedTable(partitionToDeserializer, filterOpt = None)
+ }
+
+ /**
+ * Create a HadoopRDD for every partition key specified in the query. Note that for on-disk Hive
+ * tables, a data directory is created for each partition corresponding to keys specified using
+ * 'PARTITIONED BY'.
+ *
+ * @param partitionToDeserializer Mapping from a Hive Partition metadata object to the SerDe
+ * class to use to deserialize input Writables from the corresponding partition.
+ * @param filterOpt If defined, then the filter is used to reject files contained in the data
+ * subdirectory of each partition being read. If None, then all files are accepted.
+ */
+ def makeRDDForPartitionedTable(
+ partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]],
+ filterOpt: Option[PathFilter]): RDD[_] =
+ {
+ val hivePartitionRDDs = partitionToDeserializer.map { case (partition, partDeserializer) =>
+ val partDesc = Utilities.getPartitionDesc(partition)
+ val partPath = partition.getPartitionPath
+ val inputPathStr = applyFilterIfNeeded(partPath, filterOpt)
+ val ifc = partDesc.getInputFileFormatClass
+ .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]]
+ // Get partition field info
+ val partSpec = partDesc.getPartSpec
+ val partProps = partDesc.getProperties
+
+ val partColsDelimited: String = partProps.getProperty(META_TABLE_PARTITION_COLUMNS)
+ // Partitioning columns are delimited by "/"
+ val partCols = partColsDelimited.trim().split("/").toSeq
+ // 'partValues[i]' contains the value for the partitioning column at 'partCols[i]'.
+ val partValues = if (partSpec == null) {
+ Array.fill(partCols.size)(new String)
+ } else {
+ partCols.map(col => new String(partSpec.get(col))).toArray
+ }
+
+ // Create local references so that the outer object isn't serialized.
+ val tableDesc = _tableDesc
+ val broadcastedHiveConf = _broadcastedHiveConf
+ val localDeserializer = partDeserializer
+
+ val hivePartitionRDD = createHadoopRdd(tableDesc, inputPathStr, ifc)
+ hivePartitionRDD.mapPartitions { iter =>
+ val hconf = broadcastedHiveConf.value.value
+ val rowWithPartArr = new Array[Object](2)
+ // Map each tuple to a row object
+ iter.map { value =>
+ val deserializer = localDeserializer.newInstance()
+ deserializer.initialize(hconf, partProps)
+ val deserializedRow = deserializer.deserialize(value)
+ rowWithPartArr.update(0, deserializedRow)
+ rowWithPartArr.update(1, partValues)
+ rowWithPartArr.asInstanceOf[Object]
+ }
+ }
+ }.toSeq
+ // Even if we don't use any partitions, we still need an empty RDD
+ if (hivePartitionRDDs.size == 0) {
+ new EmptyRDD[Object](sc.sparkContext)
+ } else {
+ new UnionRDD(hivePartitionRDDs(0).context, hivePartitionRDDs)
+ }
+ }
+
+ /**
+ * If `filterOpt` is defined, then it will be used to filter files from `path`. These files are
+ * returned in a single, comma-separated string.
+ */
+ private def applyFilterIfNeeded(path: Path, filterOpt: Option[PathFilter]): String = {
+ filterOpt match {
+ case Some(filter) =>
+ val fs = path.getFileSystem(sc.hiveconf)
+ val filteredFiles = fs.listStatus(path, filter).map(_.getPath.toString)
+ filteredFiles.mkString(",")
+ case None => path.toString
+ }
+ }
+
+ /**
+ * Creates a HadoopRDD based on the broadcasted HiveConf and other job properties that will be
+ * applied locally on each slave.
+ */
+ private def createHadoopRdd(
+ tableDesc: TableDesc,
+ path: String,
+ inputFormatClass: Class[InputFormat[Writable, Writable]])
+ : RDD[Writable] = {
+ val initializeJobConfFunc = HadoopTableReader.initializeLocalJobConfFunc(path, tableDesc) _
+
+ val rdd = new HadoopRDD(
+ sc.sparkContext,
+ _broadcastedHiveConf.asInstanceOf[Broadcast[SerializableWritable[Configuration]]],
+ Some(initializeJobConfFunc),
+ inputFormatClass,
+ classOf[Writable],
+ classOf[Writable],
+ _minSplitsPerRDD)
+
+ // Only take the value (skip the key) because Hive works only with values.
+ rdd.map(_._2)
+ }
+
+}
+
+private[hive] object HadoopTableReader {
+
+ /**
+ * Curried. After being given arguments for 'path' and 'tableDesc', the resulting
+ * JobConf => Unit closure is used to instantiate a HadoopRDD.
+ */
+ def initializeLocalJobConfFunc(path: String, tableDesc: TableDesc)(jobConf: JobConf) {
+ FileInputFormat.setInputPaths(jobConf, path)
+ if (tableDesc != null) {
+ Utilities.copyTableJobPropertiesToConf(tableDesc, jobConf)
+ }
+ val bufferSize = System.getProperty("spark.buffer.size", "65536")
+ jobConf.set("io.file.buffer.size", bufferSize)
+ }
+}
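Editorial sketch (not part of the patch): initializeLocalJobConfFunc is curried so that fixing the first argument list yields a JobConf => Unit closure, which is what HadoopRDD accepts. The same shape without any Hadoop dependency (Conf and the demo values are hypothetical):

object CurriedConfDemo extends App {
  case class Conf(var settings: Map[String, String] = Map.empty) {
    def set(k: String, v: String): Unit = settings += (k -> v)
  }

  def initializeConfFunc(path: String)(conf: Conf): Unit = {
    conf.set("input.path", path)
    conf.set("io.file.buffer.size", "65536")
  }

  // Partially applying the first list produces the Conf => Unit closure.
  val initFunc: Conf => Unit = initializeConfFunc("/warehouse/src") _

  val conf = Conf()
  initFunc(conf)
  println(conf.settings) // Map(input.path -> /warehouse/src, io.file.buffer.size -> 65536)
}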
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
new file mode 100644
index 0000000000..17ae4ef63c
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
@@ -0,0 +1,341 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+
+import java.io.File
+import java.util.{Set => JavaSet}
+
+import scala.collection.mutable
+import scala.collection.JavaConversions._
+import scala.language.implicitConversions
+
+import org.apache.hadoop.hive.metastore.api.{SerDeInfo, StorageDescriptor}
+import org.apache.hadoop.hive.metastore.MetaStoreUtils
+import org.apache.hadoop.hive.ql.exec.FunctionRegistry
+import org.apache.hadoop.hive.ql.io.avro.{AvroContainerOutputFormat, AvroContainerInputFormat}
+import org.apache.hadoop.hive.ql.metadata.Table
+import org.apache.hadoop.hive.serde2.avro.AvroSerDe
+import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
+import org.apache.hadoop.hive.serde2.RegexSerDe
+
+import org.apache.spark.{SparkContext, SparkConf}
+
+import catalyst.analysis._
+import catalyst.plans.logical.{LogicalPlan, NativeCommand}
+import catalyst.util._
+
+object TestHive
+ extends TestHiveContext(new SparkContext("local", "TestSQLContext", new SparkConf()))
+
+/**
+ * A locally running test instance of Spark's Hive execution engine.
+ *
+ * Data from [[testTables]] will be automatically loaded whenever a query is run over those tables.
+ * Calling [[reset]] will delete all tables and other state in the database, leaving the database
+ * in a "clean" state.
+ *
+ * TestHive is a singleton object version of this class because instantiating multiple copies of
+ * the Hive metastore seems to lead to weird non-deterministic failures. Therefore, the execution
+ * of test cases that rely on TestHive must be serialized.
+ */
+class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) {
+ self =>
+
+ // By clearing the port we force Spark to pick a new one. This allows us to rerun tests
+ // without restarting the JVM.
+ System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
+
+ override lazy val warehousePath = getTempFilePath("sparkHiveWarehouse").getCanonicalPath
+ override lazy val metastorePath = getTempFilePath("sparkHiveMetastore").getCanonicalPath
+
+ /** The location of the compiled hive distribution */
+ lazy val hiveHome = envVarToFile("HIVE_HOME")
+ /** The location of the hive source code. */
+ lazy val hiveDevHome = envVarToFile("HIVE_DEV_HOME")
+
+ // Override so we can intercept relative paths and rewrite them to point at hive.
+ override def runSqlHive(sql: String): Seq[String] = super.runSqlHive(rewritePaths(sql))
+
+ override def executePlan(plan: LogicalPlan): this.QueryExecution =
+ new this.QueryExecution { val logical = plan }
+
+ /**
+ * Returns the value of the specified environment variable as a [[java.io.File]], if the
+ * variable is set.
+ */
+ private def envVarToFile(envVar: String): Option[File] = {
+ Option(System.getenv(envVar)).map(new File(_))
+ }
+
+ /**
+ * Replaces relative paths to the parent directory "../" with hiveDevHome since this is how the
+ * hive test cases assume the system is set up.
+ */
+ private def rewritePaths(cmd: String): String =
+ if (cmd.toUpperCase contains "LOAD DATA") {
+ val testDataLocation =
+ hiveDevHome.map(_.getCanonicalPath).getOrElse(inRepoTests.getCanonicalPath)
+ cmd.replaceAll("\\.\\.", testDataLocation)
+ } else {
+ cmd
+ }
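Editorial sketch (not part of the patch): rewritePaths only touches LOAD DATA commands, replacing every ".." with the resolved test-data location. The same substitution on a literal command, with /opt/hive-dev standing in for an assumed HIVE_DEV_HOME checkout:

object RewritePathsDemo extends App {
  val cmd = "LOAD DATA LOCAL INPATH '../data/files/kv1.txt' INTO TABLE src"
  val rewritten = cmd.replaceAll("\\.\\.", "/opt/hive-dev")
  println(rewritten) // LOAD DATA LOCAL INPATH '/opt/hive-dev/data/files/kv1.txt' INTO TABLE src
}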
+
+ val hiveFilesTemp = File.createTempFile("catalystHiveFiles", "")
+ hiveFilesTemp.delete()
+ hiveFilesTemp.mkdir()
+
+ val inRepoTests = new File("src/test/resources/")
+ def getHiveFile(path: String): File = {
+ val stripped = path.replaceAll("""\.\.\/""", "")
+ hiveDevHome
+ .map(new File(_, stripped))
+ .filter(_.exists)
+ .getOrElse(new File(inRepoTests, stripped))
+ }
+
+ val describedTable = "DESCRIBE (\\w+)".r
+
+ class SqlQueryExecution(sql: String) extends this.QueryExecution {
+ lazy val logical = HiveQl.parseSql(sql)
+ def hiveExec() = runSqlHive(sql)
+ override def toString = sql + "\n" + super.toString
+ }
+
+ /**
+ * Override QueryExecution with special debug workflow.
+ */
+ abstract class QueryExecution extends super.QueryExecution {
+ override lazy val analyzed = {
+ val describedTables = logical match {
+ case NativeCommand(describedTable(tbl)) => tbl :: Nil
+ case _ => Nil
+ }
+
+ // Make sure any test tables referenced are loaded.
+ val referencedTables =
+ describedTables ++
+ logical.collect { case UnresolvedRelation(databaseName, name, _) => name }
+ val referencedTestTables = referencedTables.filter(testTables.contains)
+ logger.debug(s"Query references test tables: ${referencedTestTables.mkString(", ")}")
+ referencedTestTables.foreach(loadTestTable)
+ // Proceed with analysis.
+ analyzer(logical)
+ }
+ }
+
+ case class TestTable(name: String, commands: (()=>Unit)*)
+
+ implicit class SqlCmd(sql: String) {
+ def cmd = () => new SqlQueryExecution(sql).stringResult(): Unit
+ }
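Editorial sketch (not part of the patch): the SqlCmd implicit turns any SQL string into a () => Unit thunk, so table DDL can sit in TestTable definitions and only run when the table is first loaded. A simplified, standalone version where println stands in for SqlQueryExecution:

object SqlCmdDemo extends App {
  implicit class SqlCmd(sql: String) {
    def cmd: () => Unit = () => println(s"executing: $sql")
  }

  val commands: Seq[() => Unit] = Seq(
    "CREATE TABLE src (key INT, value STRING)".cmd,
    "LOAD DATA LOCAL INPATH 'data/files/kv1.txt' INTO TABLE src".cmd)

  commands.foreach(_()) // nothing runs until the thunks are invoked here
}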
+
+ /**
+ * A list of test tables and the DDL required to initialize them. A test table is loaded on
+ * demand when a query is run against it.
+ */
+ lazy val testTables = new mutable.HashMap[String, TestTable]()
+ def registerTestTable(testTable: TestTable) = testTables += (testTable.name -> testTable)
+
+ // The test tables that are defined in the Hive QTestUtil.
+ // /itests/util/src/main/java/org/apache/hadoop/hive/ql/QTestUtil.java
+ val hiveQTestUtilTables = Seq(
+ TestTable("src",
+ "CREATE TABLE src (key INT, value STRING)".cmd,
+ s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' INTO TABLE src".cmd),
+ TestTable("src1",
+ "CREATE TABLE src1 (key INT, value STRING)".cmd,
+ s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd),
+ TestTable("dest1",
+ "CREATE TABLE IF NOT EXISTS dest1 (key INT, value STRING)".cmd),
+ TestTable("dest2",
+ "CREATE TABLE IF NOT EXISTS dest2 (key INT, value STRING)".cmd),
+ TestTable("dest3",
+ "CREATE TABLE IF NOT EXISTS dest3 (key INT, value STRING)".cmd),
+ TestTable("srcpart", () => {
+ runSqlHive(
+ "CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)")
+ for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) {
+ runSqlHive(
+ s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}'
+ |OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr')
+ """.stripMargin)
+ }
+ }),
+ TestTable("srcpart1", () => {
+ runSqlHive("CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)")
+ for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- 11 to 12) {
+ runSqlHive(
+ s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}'
+ |OVERWRITE INTO TABLE srcpart1 PARTITION (ds='$ds',hr='$hr')
+ """.stripMargin)
+ }
+ }),
+ TestTable("src_thrift", () => {
+ import org.apache.thrift.protocol.TBinaryProtocol
+ import org.apache.hadoop.hive.serde2.thrift.test.Complex
+ import org.apache.hadoop.hive.serde2.thrift.ThriftDeserializer
+ import org.apache.hadoop.mapred.SequenceFileInputFormat
+ import org.apache.hadoop.mapred.SequenceFileOutputFormat
+
+ val srcThrift = new Table("default", "src_thrift")
+ srcThrift.setFields(Nil)
+ srcThrift.setInputFormatClass(classOf[SequenceFileInputFormat[_,_]].getName)
+ // In Hive, SequenceFileOutputFormat will be substituted by HiveSequenceFileOutputFormat.
+ srcThrift.setOutputFormatClass(classOf[SequenceFileOutputFormat[_,_]].getName)
+ srcThrift.setSerializationLib(classOf[ThriftDeserializer].getName)
+ srcThrift.setSerdeParam("serialization.class", classOf[Complex].getName)
+ srcThrift.setSerdeParam("serialization.format", classOf[TBinaryProtocol].getName)
+ catalog.client.createTable(srcThrift)
+
+
+ runSqlHive(
+ s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/complex.seq")}' INTO TABLE src_thrift")
+ }),
+ TestTable("serdeins",
+ s"""CREATE TABLE serdeins (key INT, value STRING)
+ |ROW FORMAT SERDE '${classOf[LazySimpleSerDe].getCanonicalName}'
+ |WITH SERDEPROPERTIES ('field.delim'='\\t')
+ """.stripMargin.cmd,
+ "INSERT OVERWRITE TABLE serdeins SELECT * FROM src".cmd),
+ TestTable("sales",
+ s"""CREATE TABLE IF NOT EXISTS sales (key STRING, value INT)
+ |ROW FORMAT SERDE '${classOf[RegexSerDe].getCanonicalName}'
+ |WITH SERDEPROPERTIES ("input.regex" = "([^ ]*)\t([^ ]*)")
+ """.stripMargin.cmd,
+ s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/sales.txt")}' INTO TABLE sales".cmd),
+ TestTable("episodes",
+ s"""CREATE TABLE episodes (title STRING, air_date STRING, doctor INT)
+ |ROW FORMAT SERDE '${classOf[AvroSerDe].getCanonicalName}'
+ |STORED AS
+ |INPUTFORMAT '${classOf[AvroContainerInputFormat].getCanonicalName}'
+ |OUTPUTFORMAT '${classOf[AvroContainerOutputFormat].getCanonicalName}'
+ |TBLPROPERTIES (
+ | 'avro.schema.literal'='{
+ | "type": "record",
+ | "name": "episodes",
+ | "namespace": "testing.hive.avro.serde",
+ | "fields": [
+ | {
+ | "name": "title",
+ | "type": "string",
+ | "doc": "episode title"
+ | },
+ | {
+ | "name": "air_date",
+ | "type": "string",
+ | "doc": "initial date"
+ | },
+ | {
+ | "name": "doctor",
+ | "type": "int",
+ | "doc": "main actor playing the Doctor in episode"
+ | }
+ | ]
+ | }'
+ |)
+ """.stripMargin.cmd,
+ s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/episodes.avro")}' INTO TABLE episodes".cmd
+ )
+ )
+
+ hiveQTestUtilTables.foreach(registerTestTable)
+
+ private val loadedTables = new collection.mutable.HashSet[String]
+
+ def loadTestTable(name: String) {
+ if (!(loadedTables contains name)) {
+ // Marks the table as loaded first to prevent infinite mutually recursive table loading.
+ loadedTables += name
+ logger.info(s"Loading test table $name")
+ val createCmds =
+ testTables.get(name).map(_.commands).getOrElse(sys.error(s"Unknown test table $name"))
+ createCmds.foreach(_())
+ }
+ }
+
+ /**
+ * Records the UDFs present when the server starts, so we can delete ones that are created by
+ * tests.
+ */
+ protected val originalUdfs: JavaSet[String] = FunctionRegistry.getFunctionNames
+
+ /**
+ * Resets the test instance by deleting any tables that have been created.
+ * TODO: also clear out UDFs, views, etc.
+ */
+ def reset() {
+ try {
+ // HACK: Hive is too noisy by default.
+ org.apache.log4j.LogManager.getCurrentLoggers.foreach { logger =>
+ logger.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN)
+ }
+
+ // It is important that we RESET first as broken hooks that might have been set could break
+ // other sql exec here.
+ runSqlHive("RESET")
+ // For some reason, RESET does not reset the following variables...
+ runSqlHive("set datanucleus.cache.collections=true")
+ runSqlHive("set datanucleus.cache.collections.lazy=true")
+ // Lots of tests fail if we do not change the partition whitelist from the default.
+ runSqlHive("set hive.metastore.partition.name.whitelist.pattern=.*")
+
+ loadedTables.clear()
+ catalog.client.getAllTables("default").foreach { t =>
+ logger.debug(s"Deleting table $t")
+ val table = catalog.client.getTable("default", t)
+
+ catalog.client.getIndexes("default", t, 255).foreach { index =>
+ catalog.client.dropIndex("default", t, index.getIndexName, true)
+ }
+
+ if (!table.isIndexTable) {
+ catalog.client.dropTable("default", t)
+ }
+ }
+
+ catalog.client.getAllDatabases.filterNot(_ == "default").foreach { db =>
+ logger.debug(s"Dropping Database: $db")
+ catalog.client.dropDatabase(db, true, false, true)
+ }
+
+ FunctionRegistry.getFunctionNames.filterNot(originalUdfs.contains(_)).foreach { udfName =>
+ FunctionRegistry.unregisterTemporaryUDF(udfName)
+ }
+
+ configure()
+
+ runSqlHive("USE default")
+
+ // Just loading src makes a lot of tests pass. This is because some tests do something like
+ // drop an index on src at the beginning. Since we just pass DDL to hive this bypasses our
+ // Analyzer and thus the test table auto-loading mechanism.
+ // Remove after we handle more DDL operations natively.
+ loadTestTable("src")
+ loadTestTable("srcpart")
+ } catch {
+ case e: Exception =>
+ logger.error(s"FATAL ERROR: Failed to reset TestDB state. $e")
+ // At this point there is really no reason to continue, but the test framework traps exits.
+ // So instead we just pause forever so that at least the developer can see where things
+ // started to go wrong.
+ Thread.sleep(100000)
+ }
+ }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala
new file mode 100644
index 0000000000..d20fd87f34
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala
@@ -0,0 +1,356 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+
+import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar}
+import org.apache.hadoop.hive.metastore.MetaStoreUtils
+import org.apache.hadoop.hive.ql.Context
+import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Hive}
+import org.apache.hadoop.hive.ql.plan.{TableDesc, FileSinkDesc}
+import org.apache.hadoop.hive.serde2.Serializer
+import org.apache.hadoop.hive.serde2.objectinspector._
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveDecimalObjectInspector
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveVarcharObjectInspector
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.Writable
+import org.apache.hadoop.mapred._
+
+import catalyst.expressions._
+import catalyst.types.{BooleanType, DataType}
+import org.apache.spark.{TaskContext, SparkException}
+import catalyst.expressions.Cast
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.execution._
+
+import scala.Some
+import scala.collection.immutable.ListMap
+
+/* Implicits */
+import scala.collection.JavaConversions._
+
+/**
+ * The Hive table scan operator. Column and partition pruning are both handled.
+ *
+ * @constructor
+ * @param attributes Attributes to be fetched from the Hive table.
+ * @param relation The Hive table to be scanned.
+ * @param partitionPruningPred An optional partition pruning predicate for partitioned table.
+ */
+case class HiveTableScan(
+ attributes: Seq[Attribute],
+ relation: MetastoreRelation,
+ partitionPruningPred: Option[Expression])(
+ @transient val sc: HiveContext)
+ extends LeafNode
+ with HiveInspectors {
+
+ require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned,
+ "Partition pruning predicates only supported for partitioned tables.")
+
+ // Bind all partition key attribute references in the partition pruning predicate for later
+ // evaluation.
+ private val boundPruningPred = partitionPruningPred.map { pred =>
+ require(
+ pred.dataType == BooleanType,
+ s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.")
+
+ BindReferences.bindReference(pred, relation.partitionKeys)
+ }
+
+ @transient
+ val hadoopReader = new HadoopTableReader(relation.tableDesc, sc)
+
+ /**
+ * The hive object inspector for this table, which can be used to extract values from the
+ * serialized row representation.
+ */
+ @transient
+ lazy val objectInspector =
+ relation.tableDesc.getDeserializer.getObjectInspector.asInstanceOf[StructObjectInspector]
+
+ /**
+ * Functions that extract the requested attributes from the hive output. Partition values are
+ * cast from strings to their declared data types.
+ */
+ @transient
+ protected lazy val attributeFunctions: Seq[(Any, Array[String]) => Any] = {
+ attributes.map { a =>
+ val ordinal = relation.partitionKeys.indexOf(a)
+ if (ordinal >= 0) {
+ (_: Any, partitionKeys: Array[String]) => {
+ val value = partitionKeys(ordinal)
+ val dataType = relation.partitionKeys(ordinal).dataType
+ castFromString(value, dataType)
+ }
+ } else {
+ val ref = objectInspector.getAllStructFieldRefs
+ .find(_.getFieldName == a.name)
+ .getOrElse(sys.error(s"Can't find attribute $a"))
+ (row: Any, _: Array[String]) => {
+ val data = objectInspector.getStructFieldData(row, ref)
+ unwrapData(data, ref.getFieldObjectInspector)
+ }
+ }
+ }
+ }
+
+ private def castFromString(value: String, dataType: DataType) = {
+ Cast(Literal(value), dataType).apply(null)
+ }
+
+ @transient
+ def inputRdd = if (!relation.hiveQlTable.isPartitioned) {
+ hadoopReader.makeRDDForTable(relation.hiveQlTable)
+ } else {
+ hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions))
+ }
+
+ /**
+ * Prunes partitions not involved in the query plan.
+ *
+ * @param partitions All partitions of the relation.
+ * @return Partitions that are involved in the query plan.
+ */
+ private[hive] def prunePartitions(partitions: Seq[HivePartition]) = {
+ boundPruningPred match {
+ case None => partitions
+ case Some(shouldKeep) => partitions.filter { part =>
+ val dataTypes = relation.partitionKeys.map(_.dataType)
+ val castedValues = for ((value, dataType) <- part.getValues.zip(dataTypes)) yield {
+ castFromString(value, dataType)
+ }
+
+ // Only partitioned values are needed here, since the predicate has already been bound to
+ // partition key attribute references.
+ val row = new GenericRow(castedValues.toArray)
+ shouldKeep.apply(row).asInstanceOf[Boolean]
+ }
+ }
+ }
+
+ def execute() = {
+ inputRdd.map { row =>
+ val values = row match {
+ case Array(deserializedRow: AnyRef, partitionKeys: Array[String]) =>
+ attributeFunctions.map(_(deserializedRow, partitionKeys))
+ case deserializedRow: AnyRef =>
+ attributeFunctions.map(_(deserializedRow, Array.empty))
+ }
+ buildRow(values.map {
+ case n: String if n.toLowerCase == "null" => null
+ case varchar: org.apache.hadoop.hive.common.`type`.HiveVarchar => varchar.getValue
+ case decimal: org.apache.hadoop.hive.common.`type`.HiveDecimal =>
+ BigDecimal(decimal.bigDecimalValue)
+ case other => other
+ })
+ }
+ }
+
+ def output = attributes
+}
+
+case class InsertIntoHiveTable(
+ table: MetastoreRelation,
+ partition: Map[String, Option[String]],
+ child: SparkPlan,
+ overwrite: Boolean)
+ (@transient sc: HiveContext)
+ extends UnaryNode {
+
+ val outputClass = newSerializer(table.tableDesc).getSerializedClass
+ @transient private val hiveContext = new Context(sc.hiveconf)
+ @transient private val db = Hive.get(sc.hiveconf)
+
+ private def newSerializer(tableDesc: TableDesc): Serializer = {
+ val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer]
+ serializer.initialize(null, tableDesc.getProperties)
+ serializer
+ }
+
+ override def otherCopyArgs = sc :: Nil
+
+ def output = child.output
+
+ /**
+ * Wraps with Hive types based on object inspector.
+ * TODO: Consolidate all hive OI/data interface code.
+ */
+ protected def wrap(a: (Any, ObjectInspector)): Any = a match {
+ case (s: String, oi: JavaHiveVarcharObjectInspector) => new HiveVarchar(s, s.size)
+ case (bd: BigDecimal, oi: JavaHiveDecimalObjectInspector) =>
+ new HiveDecimal(bd.underlying())
+ case (row: Row, oi: StandardStructObjectInspector) =>
+ val struct = oi.create()
+ row.zip(oi.getAllStructFieldRefs).foreach {
+ case (data, field) =>
+ oi.setStructFieldData(struct, field, wrap(data, field.getFieldObjectInspector))
+ }
+ struct
+ case (s: Seq[_], oi: ListObjectInspector) =>
+ val wrappedSeq = s.map(wrap(_, oi.getListElementObjectInspector))
+ seqAsJavaList(wrappedSeq)
+ case (obj, _) => obj
+ }
+
+ def saveAsHiveFile(
+ rdd: RDD[Writable],
+ valueClass: Class[_],
+ fileSinkConf: FileSinkDesc,
+ conf: JobConf,
+ isCompressed: Boolean) {
+ if (valueClass == null) {
+ throw new SparkException("Output value class not set")
+ }
+ conf.setOutputValueClass(valueClass)
+ if (fileSinkConf.getTableInfo.getOutputFileFormatClassName == null) {
+ throw new SparkException("Output format class not set")
+ }
+ // Doesn't work in Scala 2.9 due to what may be a generics bug
+ // TODO: Should we uncomment this for Scala 2.10?
+ // conf.setOutputFormat(outputFormatClass)
+ conf.set("mapred.output.format.class", fileSinkConf.getTableInfo.getOutputFileFormatClassName)
+ if (isCompressed) {
+ // Please note that isCompressed, "mapred.output.compress", "mapred.output.compression.codec",
+ // and "mapred.output.compression.type" have no impact on ORC because it uses table properties
+ // to store compression information.
+ conf.set("mapred.output.compress", "true")
+ fileSinkConf.setCompressed(true)
+ fileSinkConf.setCompressCodec(conf.get("mapred.output.compression.codec"))
+ fileSinkConf.setCompressType(conf.get("mapred.output.compression.type"))
+ }
+ conf.setOutputCommitter(classOf[FileOutputCommitter])
+ FileOutputFormat.setOutputPath(
+ conf,
+ SparkHiveHadoopWriter.createPathFromString(fileSinkConf.getDirName, conf))
+
+ logger.debug("Saving as hadoop file of type " + valueClass.getSimpleName)
+
+ val writer = new SparkHiveHadoopWriter(conf, fileSinkConf)
+ writer.preSetup()
+
+ def writeToFile(context: TaskContext, iter: Iterator[Writable]) {
+ // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
+ // around by taking a mod. We expect that no task will be attempted 2 billion times.
+ val attemptNumber = (context.attemptId % Int.MaxValue).toInt
+
+ writer.setup(context.stageId, context.partitionId, attemptNumber)
+ writer.open()
+
+ var count = 0
+ while(iter.hasNext) {
+ val record = iter.next()
+ count += 1
+ writer.write(record)
+ }
+
+ writer.close()
+ writer.commit()
+ }
+
+ sc.sparkContext.runJob(rdd, writeToFile _)
+ writer.commitJob()
+ }
+
+ /**
+ * Inserts all the rows in the table into Hive. Row objects are properly serialized with the
+ * `org.apache.hadoop.hive.serde2.SerDe` and the
+ * `org.apache.hadoop.mapred.OutputFormat` provided by the table definition.
+ */
+ def execute() = {
+ val childRdd = child.execute()
+ assert(childRdd != null)
+
+ // Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer
+ // instances within the closure, since Serializer is not serializable while TableDesc is.
+ val tableDesc = table.tableDesc
+ val tableLocation = table.hiveQlTable.getDataLocation
+ val tmpLocation = hiveContext.getExternalTmpFileURI(tableLocation)
+ val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false)
+ val rdd = childRdd.mapPartitions { iter =>
+ val serializer = newSerializer(fileSinkConf.getTableInfo)
+ val standardOI = ObjectInspectorUtils
+ .getStandardObjectInspector(
+ fileSinkConf.getTableInfo.getDeserializer.getObjectInspector,
+ ObjectInspectorCopyOption.JAVA)
+ .asInstanceOf[StructObjectInspector]
+
+ iter.map { row =>
+ // Casts Strings to HiveVarchars when necessary.
+ val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector)
+ val mappedRow = row.zip(fieldOIs).map(wrap)
+
+ serializer.serialize(mappedRow.toArray, standardOI)
+ }
+ }
+
+ // ORC stores compression information in table properties, while other formats (e.g. RCFile)
+ // rely on Hadoop configurations to store it.
+ val jobConf = new JobConf(sc.hiveconf)
+ saveAsHiveFile(
+ rdd,
+ outputClass,
+ fileSinkConf,
+ jobConf,
+ sc.hiveconf.getBoolean("hive.exec.compress.output", false))
+
+ // TODO: Handle dynamic partitioning.
+ val outputPath = FileOutputFormat.getOutputPath(jobConf)
+ // Have to construct the format of dbname.tablename.
+ val qualifiedTableName = s"${table.databaseName}.${table.tableName}"
+ // TODO: Correctly set holdDDLTime.
+ // Most of the time, holdDDLTime should be false.
+ // holdDDLTime will be true when TOK_HOLD_DDLTIME is present in the query as a hint.
+ val holdDDLTime = false
+ if (partition.nonEmpty) {
+ val partitionSpec = partition.map {
+ case (key, Some(value)) => key -> value
+ case (key, None) => key -> "" // Should not reach here right now.
+ }
+ val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols(), partitionSpec)
+ db.validatePartitionNameCharacters(partVals)
+ // inheritTableSpecs is set to true. It should be set to false for an IMPORT query
+ // which is currently considered as a Hive native command.
+ val inheritTableSpecs = true
+ // TODO: Correctly set isSkewedStoreAsSubdir.
+ val isSkewedStoreAsSubdir = false
+ db.loadPartition(
+ outputPath,
+ qualifiedTableName,
+ partitionSpec,
+ overwrite,
+ holdDDLTime,
+ inheritTableSpecs,
+ isSkewedStoreAsSubdir)
+ } else {
+ db.loadTable(
+ outputPath,
+ qualifiedTableName,
+ overwrite,
+ holdDDLTime)
+ }
+
+ // It would be nice to just return the childRdd unchanged so insert operations could be chained;
+ // however, for now we return an empty list to simplify compatibility checks with Hive, which
+ // does not return anything for insert operations.
+ // TODO: implement hive compatibility as rules.
+ sc.sparkContext.makeRDD(Nil, 1)
+ }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
new file mode 100644
index 0000000000..5e775d6a04
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -0,0 +1,467 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.hadoop.hive.common.`type`.HiveDecimal
+import org.apache.hadoop.hive.serde2.{io => hiveIo}
+import org.apache.hadoop.hive.serde2.objectinspector.primitive._
+import org.apache.hadoop.hive.serde2.objectinspector._
+import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry}
+import org.apache.hadoop.hive.ql.udf.generic._
+import org.apache.hadoop.hive.ql.exec.UDF
+import org.apache.hadoop.{io => hadoopIo}
+
+import catalyst.analysis
+import catalyst.expressions._
+import catalyst.types
+import catalyst.types._
+
+object HiveFunctionRegistry
+ extends analysis.FunctionRegistry with HiveFunctionFactory with HiveInspectors {
+
+ def lookupFunction(name: String, children: Seq[Expression]): Expression = {
+ // We only look it up to see if it exists, but do not include it in the HiveUDF since it is
+ // not always serializable.
+ val functionInfo: FunctionInfo = Option(FunctionRegistry.getFunctionInfo(name)).getOrElse(
+ sys.error(s"Couldn't find function $name"))
+
+ if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
+ val function = createFunction[UDF](name)
+ val method = function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo))
+
+ lazy val expectedDataTypes = method.getParameterTypes.map(javaClassToDataType)
+
+ HiveSimpleUdf(
+ name,
+ children.zip(expectedDataTypes).map { case (e, t) => Cast(e, t) }
+ )
+ } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
+ HiveGenericUdf(name, children)
+ } else if (
+ classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
+ HiveGenericUdaf(name, children)
+
+ } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
+ HiveGenericUdtf(name, Nil, children)
+ } else {
+ sys.error(s"No handler for udf ${functionInfo.getFunctionClass}")
+ }
+ }
+
+ def javaClassToDataType(clz: Class[_]): DataType = clz match {
+ case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType
+ case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType
+ case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType
+ case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType
+ case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType
+ case c: Class[_] if c == classOf[hadoopIo.Text] => StringType
+ case c: Class[_] if c == classOf[hadoopIo.IntWritable] => IntegerType
+ case c: Class[_] if c == classOf[hadoopIo.LongWritable] => LongType
+ case c: Class[_] if c == classOf[hadoopIo.FloatWritable] => FloatType
+ case c: Class[_] if c == classOf[hadoopIo.BooleanWritable] => BooleanType
+ case c: Class[_] if c == classOf[java.lang.String] => StringType
+ case c: Class[_] if c == java.lang.Short.TYPE => ShortType
+ case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType
+ case c: Class[_] if c == java.lang.Long.TYPE => LongType
+ case c: Class[_] if c == java.lang.Double.TYPE => DoubleType
+ case c: Class[_] if c == java.lang.Byte.TYPE => ByteType
+ case c: Class[_] if c == java.lang.Float.TYPE => FloatType
+ case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType
+ case c: Class[_] if c == classOf[java.lang.Short] => ShortType
+ case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
+ case c: Class[_] if c == classOf[java.lang.Long] => LongType
+ case c: Class[_] if c == classOf[java.lang.Double] => DoubleType
+ case c: Class[_] if c == classOf[java.lang.Byte] => ByteType
+ case c: Class[_] if c == classOf[java.lang.Float] => FloatType
+ case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType
+ case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType))
+ }
+}
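
(Editorial sketch, not part of the patch: resolving a Hive function name through the registry
above. Which expression comes back -- HiveSimpleUdf, HiveGenericUdf, HiveGenericUdaf or
HiveGenericUdtf -- depends on which Hive interface the registered class implements. The snippet
assumes the catalyst.expressions imports already in scope in this file.)

  // Look up Hive's built-in "length" function applied to a string literal.
  val lengthExpr = HiveFunctionRegistry.lookupFunction("length", Seq(Literal("spark")))
  // For a simple UDF, the returned expression's dataType is derived from the eval method's
  // return type via javaClassToDataType, and arguments are wrapped in Casts to the expected types.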
+
+trait HiveFunctionFactory {
+ def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name)
+ def getFunctionClass(name: String) = getFunctionInfo(name).getFunctionClass
+ def createFunction[UDFType](name: String) =
+ getFunctionClass(name).newInstance.asInstanceOf[UDFType]
+
+ /** Converts hive types to native catalyst types. */
+ def unwrap(a: Any): Any = a match {
+ case null => null
+ case i: hadoopIo.IntWritable => i.get
+ case t: hadoopIo.Text => t.toString
+ case l: hadoopIo.LongWritable => l.get
+ case d: hadoopIo.DoubleWritable => d.get()
+ case d: hiveIo.DoubleWritable => d.get
+ case s: hiveIo.ShortWritable => s.get
+ case b: hadoopIo.BooleanWritable => b.get()
+ case b: hiveIo.ByteWritable => b.get
+ case list: java.util.List[_] => list.map(unwrap)
+ case map: java.util.Map[_,_] => map.map { case (k, v) => (unwrap(k), unwrap(v)) }.toMap
+ case array: Array[_] => array.map(unwrap).toSeq
+ case p: java.lang.Short => p
+ case p: java.lang.Long => p
+ case p: java.lang.Float => p
+ case p: java.lang.Integer => p
+ case p: java.lang.Double => p
+ case p: java.lang.Byte => p
+ case p: java.lang.Boolean => p
+ case str: String => str
+ }
+}
+
+abstract class HiveUdf
+ extends Expression with Logging with HiveFunctionFactory {
+ self: Product =>
+
+ type UDFType
+ type EvaluatedType = Any
+
+ val name: String
+
+ def nullable = true
+ def references = children.flatMap(_.references).toSet
+
+ // FunctionInfo is not serializable so we must look it up here again.
+ lazy val functionInfo = getFunctionInfo(name)
+ lazy val function = createFunction[UDFType](name)
+
+ override def toString = s"${nodeName}#${functionInfo.getDisplayName}(${children.mkString(",")})"
+}
+
+case class HiveSimpleUdf(name: String, children: Seq[Expression]) extends HiveUdf {
+ import HiveFunctionRegistry._
+ type UDFType = UDF
+
+ @transient
+ protected lazy val method =
+ function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo))
+
+ @transient
+ lazy val dataType = javaClassToDataType(method.getReturnType)
+
+ protected lazy val wrappers: Array[(Any) => AnyRef] = method.getParameterTypes.map { argClass =>
+ val primitiveClasses = Seq(
+ Integer.TYPE, classOf[java.lang.Integer], classOf[java.lang.String], java.lang.Double.TYPE,
+ classOf[java.lang.Double], java.lang.Long.TYPE, classOf[java.lang.Long],
+ classOf[HiveDecimal], java.lang.Byte.TYPE, classOf[java.lang.Byte]
+ )
+ val matchingConstructor = argClass.getConstructors.find { c =>
+ c.getParameterTypes.size == 1 && primitiveClasses.contains(c.getParameterTypes.head)
+ }
+
+ val constructor = matchingConstructor.getOrElse(
+ sys.error(s"No matching wrapper found, options: ${argClass.getConstructors.toSeq}."))
+
+ (a: Any) => {
+ logger.debug(
+ s"Wrapping $a of type ${if (a == null) "null" else a.getClass.getName} using $constructor.")
+ // We must make sure that primitives get boxed java style.
+ if (a == null) {
+ null
+ } else {
+ constructor.newInstance(a match {
+ case i: Int => i: java.lang.Integer
+ case bd: BigDecimal => new HiveDecimal(bd.underlying())
+ case other: AnyRef => other
+ }).asInstanceOf[AnyRef]
+ }
+ }
+ }
+
+ // TODO: Finish input output types.
+ override def apply(input: Row): Any = {
+ val evaluatedChildren = children.map(_.apply(input))
+ // Wrap the function arguments in the expected types.
+ val args = evaluatedChildren.zip(wrappers).map {
+ case (arg, wrapper) => wrapper(arg)
+ }
+
+ // Invoke the udf and unwrap the result.
+ unwrap(method.invoke(function, args: _*))
+ }
+}
+
+case class HiveGenericUdf(
+ name: String,
+ children: Seq[Expression]) extends HiveUdf with HiveInspectors {
+ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._
+ type UDFType = GenericUDF
+
+ @transient
+ protected lazy val argumentInspectors = children.map(_.dataType).map(toInspector)
+
+ @transient
+ protected lazy val returnInspector = function.initialize(argumentInspectors.toArray)
+
+ val dataType: DataType = inspectorToDataType(returnInspector)
+
+ override def apply(input: Row): Any = {
+ returnInspector // Make sure initialized.
+ val args = children.map { v =>
+ new DeferredObject {
+ override def prepare(i: Int) = {}
+ override def get(): AnyRef = wrap(v.apply(input))
+ }
+ }.toArray
+ unwrap(function.evaluate(args))
+ }
+}
+
+trait HiveInspectors {
+
+ def unwrapData(data: Any, oi: ObjectInspector): Any = oi match {
+ case pi: PrimitiveObjectInspector => pi.getPrimitiveJavaObject(data)
+ case li: ListObjectInspector =>
+ Option(li.getList(data))
+ .map(_.map(unwrapData(_, li.getListElementObjectInspector)).toSeq)
+ .orNull
+ case mi: MapObjectInspector =>
+ Option(mi.getMap(data)).map(
+ _.map {
+ case (k,v) =>
+ (unwrapData(k, mi.getMapKeyObjectInspector),
+ unwrapData(v, mi.getMapValueObjectInspector))
+ }.toMap).orNull
+ case si: StructObjectInspector =>
+ val allRefs = si.getAllStructFieldRefs
+ new GenericRow(
+ allRefs.map(r =>
+ unwrapData(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray)
+ }
+
+ /** Converts native catalyst types to the types expected by Hive */
+ def wrap(a: Any): AnyRef = a match {
+ case s: String => new hadoopIo.Text(s)
+ case i: Int => i: java.lang.Integer
+ case b: Boolean => b: java.lang.Boolean
+ case d: Double => d: java.lang.Double
+ case l: Long => l: java.lang.Long
+ case l: Short => l: java.lang.Short
+ case l: Byte => l: java.lang.Byte
+ case s: Seq[_] => seqAsJavaList(s.map(wrap))
+ case m: Map[_,_] =>
+ mapAsJavaMap(m.map { case (k, v) => wrap(k) -> wrap(v) })
+ case null => null
+ }
+
+ def toInspector(dataType: DataType): ObjectInspector = dataType match {
+ case ArrayType(tpe) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe))
+ case MapType(keyType, valueType) =>
+ ObjectInspectorFactory.getStandardMapObjectInspector(
+ toInspector(keyType), toInspector(valueType))
+ case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector
+ case IntegerType => PrimitiveObjectInspectorFactory.javaIntObjectInspector
+ case DoubleType => PrimitiveObjectInspectorFactory.javaDoubleObjectInspector
+ case BooleanType => PrimitiveObjectInspectorFactory.javaBooleanObjectInspector
+ case LongType => PrimitiveObjectInspectorFactory.javaLongObjectInspector
+ case FloatType => PrimitiveObjectInspectorFactory.javaFloatObjectInspector
+ case ShortType => PrimitiveObjectInspectorFactory.javaShortObjectInspector
+ case ByteType => PrimitiveObjectInspectorFactory.javaByteObjectInspector
+ case NullType => PrimitiveObjectInspectorFactory.javaVoidObjectInspector
+ case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector
+ }
+
+ def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match {
+ case s: StructObjectInspector =>
+ StructType(s.getAllStructFieldRefs.map(f => {
+ types.StructField(
+ f.getFieldName, inspectorToDataType(f.getFieldObjectInspector), nullable = true)
+ }))
+ case l: ListObjectInspector => ArrayType(inspectorToDataType(l.getListElementObjectInspector))
+ case m: MapObjectInspector =>
+ MapType(
+ inspectorToDataType(m.getMapKeyObjectInspector),
+ inspectorToDataType(m.getMapValueObjectInspector))
+ case _: WritableStringObjectInspector => StringType
+ case _: JavaStringObjectInspector => StringType
+ case _: WritableIntObjectInspector => IntegerType
+ case _: JavaIntObjectInspector => IntegerType
+ case _: WritableDoubleObjectInspector => DoubleType
+ case _: JavaDoubleObjectInspector => DoubleType
+ case _: WritableBooleanObjectInspector => BooleanType
+ case _: JavaBooleanObjectInspector => BooleanType
+ case _: WritableLongObjectInspector => LongType
+ case _: JavaLongObjectInspector => LongType
+ case _: WritableShortObjectInspector => ShortType
+ case _: JavaShortObjectInspector => ShortType
+ case _: WritableByteObjectInspector => ByteType
+ case _: JavaByteObjectInspector => ByteType
+ }
+
+ implicit class typeInfoConversions(dt: DataType) {
+ import org.apache.hadoop.hive.serde2.typeinfo._
+ import TypeInfoFactory._
+
+ def toTypeInfo: TypeInfo = dt match {
+ case BinaryType => binaryTypeInfo
+ case BooleanType => booleanTypeInfo
+ case ByteType => byteTypeInfo
+ case DoubleType => doubleTypeInfo
+ case FloatType => floatTypeInfo
+ case IntegerType => intTypeInfo
+ case LongType => longTypeInfo
+ case ShortType => shortTypeInfo
+ case StringType => stringTypeInfo
+ case DecimalType => decimalTypeInfo
+ case NullType => voidTypeInfo
+ }
+ }
+}
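
(Editorial sketch, not part of the patch: the Catalyst <-> Hive conversions defined above, shown
on a hypothetical helper object that mixes in both traits. ArrayType and StringType come from
catalyst.types, and hadoopIo is org.apache.hadoop.io as aliased at the top of this file.)

  object ConversionSketch extends HiveFunctionFactory with HiveInspectors {
    // Catalyst-native values are boxed into the writables / Java boxes Hive expects...
    val toHive: AnyRef = wrap("hello")                         // becomes a hadoopIo.Text
    // ...and Hive results are unboxed back into Catalyst-native values.
    val toCatalyst: Any = unwrap(new hadoopIo.IntWritable(42)) // becomes the Int 42
    // toInspector builds the ObjectInspector Hive needs for a given Catalyst DataType.
    val listInspector = toInspector(ArrayType(StringType))
  }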
+
+case class HiveGenericUdaf(
+ name: String,
+ children: Seq[Expression]) extends AggregateExpression
+ with HiveInspectors
+ with HiveFunctionFactory {
+
+ type UDFType = AbstractGenericUDAFResolver
+
+ protected lazy val resolver: AbstractGenericUDAFResolver = createFunction(name)
+
+ protected lazy val objectInspector = {
+ resolver.getEvaluator(children.map(_.dataType.toTypeInfo).toArray)
+ .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
+ }
+
+ protected lazy val inspectors = children.map(_.dataType).map(toInspector)
+
+ def dataType: DataType = inspectorToDataType(objectInspector)
+
+ def nullable: Boolean = true
+
+ def references: Set[Attribute] = children.map(_.references).flatten.toSet
+
+ override def toString = s"$nodeName#$name(${children.mkString(",")})"
+
+ def newInstance = new HiveUdafFunction(name, children, this)
+}
+
+/**
+ * Converts a Hive Generic User Defined Table Generating Function (UDTF) to a
+ * [[catalyst.expressions.Generator Generator]]. Note that the semantics of Generators do not allow
+ * them to maintain state between input rows. Thus UDTFs that rely on partitioning-dependent
+ * operations, such as calls to `close()` before producing output, will not behave the same as
+ * in Hive. However, in practice this should not affect compatibility for most sane UDTFs
+ * (e.g. explode or GenericUDTFParseUrlTuple).
+ *
+ * Operators that require maintaining state in between input rows should instead be implemented as
+ * user defined aggregations, which have clean semantics even in a partitioned execution.
+ */
+case class HiveGenericUdtf(
+ name: String,
+ aliasNames: Seq[String],
+ children: Seq[Expression])
+ extends Generator with HiveInspectors with HiveFunctionFactory {
+
+ override def references = children.flatMap(_.references).toSet
+
+ @transient
+ protected lazy val function: GenericUDTF = createFunction(name)
+
+ protected lazy val inputInspectors = children.map(_.dataType).map(toInspector)
+
+ protected lazy val outputInspectors = {
+ val structInspector = function.initialize(inputInspectors.toArray)
+ structInspector.getAllStructFieldRefs.map(_.getFieldObjectInspector)
+ }
+
+ protected lazy val outputDataTypes = outputInspectors.map(inspectorToDataType)
+
+ override protected def makeOutput() = {
+ // Use column names when given, otherwise c_0, c_1, ... c_{n-1}.
+ if (aliasNames.size == outputDataTypes.size) {
+ aliasNames.zip(outputDataTypes).map {
+ case (attrName, attrDataType) =>
+ AttributeReference(attrName, attrDataType, nullable = true)()
+ }
+ } else {
+ outputDataTypes.zipWithIndex.map {
+ case (attrDataType, i) =>
+ AttributeReference(s"c_$i", attrDataType, nullable = true)()
+ }
+ }
+ }
+
+ override def apply(input: Row): TraversableOnce[Row] = {
+ outputInspectors // Make sure initialized.
+
+ val inputProjection = new Projection(children)
+ val collector = new UDTFCollector
+ function.setCollector(collector)
+
+ val udtInput = inputProjection(input).map(wrap).toArray
+ function.process(udtInput)
+ collector.collectRows()
+ }
+
+ protected class UDTFCollector extends Collector {
+ var collected = new ArrayBuffer[Row]
+
+ override def collect(input: java.lang.Object) {
+ // We need to clone the input here because implementations of
+ // GenericUDTF reuse the same object. Luckily they are always an array, so
+ // it is easy to clone.
+ collected += new GenericRow(input.asInstanceOf[Array[_]].map(unwrap))
+ }
+
+ def collectRows() = {
+ val toCollect = collected
+ collected = new ArrayBuffer[Row]
+ toCollect
+ }
+ }
+
+ override def toString() = s"$nodeName#$name(${children.mkString(",")})"
+}
+
+case class HiveUdafFunction(
+ functionName: String,
+ exprs: Seq[Expression],
+ base: AggregateExpression)
+ extends AggregateFunction
+ with HiveInspectors
+ with HiveFunctionFactory {
+
+ def this() = this(null, null, null)
+
+ private val resolver = createFunction[AbstractGenericUDAFResolver](functionName)
+
+ private val inspectors = exprs.map(_.dataType).map(toInspector).toArray
+
+ private val function = resolver.getEvaluator(exprs.map(_.dataType.toTypeInfo).toArray)
+
+ private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors)
+
+ // Cast required to avoid type inference selecting a deprecated Hive API.
+ private val buffer =
+ function.getNewAggregationBuffer.asInstanceOf[GenericUDAFEvaluator.AbstractAggregationBuffer]
+
+ override def apply(input: Row): Any = unwrapData(function.evaluate(buffer), returnInspector)
+
+ @transient
+ val inputProjection = new Projection(exprs)
+
+ def update(input: Row): Unit = {
+ val inputs = inputProjection(input).asInstanceOf[Seq[AnyRef]].toArray
+ function.iterate(buffer, inputs)
+ }
+}
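
(Editorial sketch, not part of the patch: end to end, a UDTF such as explode is resolved by
HiveFunctionRegistry into a HiveGenericUdtf, and its output rows are gathered by the
UDTFCollector above. A minimal way to exercise that path, assuming the TestHive fixture added
elsewhere in this patch and its standard `src` table:)

  import org.apache.spark.sql.hive.TestHive
  // Each row of src fans out into two rows, one per array element produced by explode.
  val exploded = TestHive.sql("FROM src SELECT explode(array(key + 3, key + 4))").collect()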
diff --git a/sql/hive/src/test/resources/log4j.properties b/sql/hive/src/test/resources/log4j.properties
new file mode 100644
index 0000000000..5e17e3b596
--- /dev/null
+++ b/sql/hive/src/test/resources/log4j.properties
@@ -0,0 +1,47 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Set everything to be logged to the file target/unit-tests.log
+log4j.rootLogger=DEBUG, CA, FA
+
+#Console Appender
+log4j.appender.CA=org.apache.log4j.ConsoleAppender
+log4j.appender.CA.layout=org.apache.log4j.PatternLayout
+log4j.appender.CA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %p %c: %m%n
+log4j.appender.CA.Threshold = WARN
+
+
+#File Appender
+log4j.appender.FA=org.apache.log4j.FileAppender
+log4j.appender.FA.append=false
+log4j.appender.FA.file=target/unit-tests.log
+log4j.appender.FA.layout=org.apache.log4j.PatternLayout
+log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %p %c{1}: %m%n
+
+# Set the logger level of File Appender to INFO
+log4j.appender.FA.Threshold = INFO
+
+# Some packages are noisy for no good reason.
+log4j.additivity.org.apache.hadoop.hive.serde2.lazy.LazyStruct=false
+log4j.logger.org.apache.hadoop.hive.serde2.lazy.LazyStruct=OFF
+
+log4j.additivity.org.apache.hadoop.hive.metastore.RetryingHMSHandler=false
+log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=OFF
+
+log4j.additivity.hive.ql.metadata.Hive=false
+log4j.logger.hive.ql.metadata.Hive=OFF
+
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala
new file mode 100644
index 0000000000..4b45e69860
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+package execution
+
+import java.io.File
+
+/**
+ * A set of test cases based on the big-data-benchmark.
+ * https://amplab.cs.berkeley.edu/benchmark/
+ */
+class BigDataBenchmarkSuite extends HiveComparisonTest {
+ import TestHive._
+
+ val testDataDirectory = new File("target/big-data-benchmark-testdata")
+
+ val testTables = Seq(
+ TestTable(
+ "rankings",
+ s"""
+ |CREATE EXTERNAL TABLE rankings (
+ | pageURL STRING,
+ | pageRank INT,
+ | avgDuration INT)
+ | ROW FORMAT DELIMITED FIELDS TERMINATED BY ","
+ | STORED AS TEXTFILE LOCATION "${new File(testDataDirectory, "rankings").getCanonicalPath}"
+ """.stripMargin.cmd),
+ TestTable(
+ "scratch",
+ s"""
+ |CREATE EXTERNAL TABLE scratch (
+ | pageURL STRING,
+ | pageRank INT,
+ | avgDuration INT)
+ | ROW FORMAT DELIMITED FIELDS TERMINATED BY ","
+ | STORED AS TEXTFILE LOCATION "${new File(testDataDirectory, "scratch").getCanonicalPath}"
+ """.stripMargin.cmd),
+ TestTable(
+ "uservisits",
+ s"""
+ |CREATE EXTERNAL TABLE uservisits (
+ | sourceIP STRING,
+ | destURL STRING,
+ | visitDate STRING,
+ | adRevenue DOUBLE,
+ | userAgent STRING,
+ | countryCode STRING,
+ | languageCode STRING,
+ | searchWord STRING,
+ | duration INT)
+ | ROW FORMAT DELIMITED FIELDS TERMINATED BY ","
+ | STORED AS TEXTFILE LOCATION "${new File(testDataDirectory, "uservisits").getCanonicalPath}"
+ """.stripMargin.cmd),
+ TestTable(
+ "documents",
+ s"""
+ |CREATE EXTERNAL TABLE documents (line STRING)
+ |STORED AS TEXTFILE
+ |LOCATION "${new File(testDataDirectory, "crawl").getCanonicalPath}"
+ """.stripMargin.cmd))
+
+ testTables.foreach(registerTestTable)
+
+ if (!testDataDirectory.exists()) {
+ // TODO: Auto download the files on demand.
+ ignore("No data files found for BigDataBenchmark tests.") {}
+ } else {
+ createQueryTest("query1",
+ "SELECT pageURL, pageRank FROM rankings WHERE pageRank > 1")
+
+ createQueryTest("query2",
+ "SELECT SUBSTR(sourceIP, 1, 10), SUM(adRevenue) FROM uservisits GROUP BY SUBSTR(sourceIP, 1, 10)")
+
+ createQueryTest("query3",
+ """
+ |SELECT sourceIP,
+ | sum(adRevenue) as totalRevenue,
+ | avg(pageRank) as pageRank
+ |FROM
+ | rankings R JOIN
+ | (SELECT sourceIP, destURL, adRevenue
+ | FROM uservisits UV
+ | WHERE UV.visitDate > "1980-01-01"
+ | AND UV.visitDate < "1980-04-01")
+ | NUV ON (R.pageURL = NUV.destURL)
+ |GROUP BY sourceIP
+ |ORDER BY totalRevenue DESC
+ |LIMIT 1
+ """.stripMargin)
+
+ createQueryTest("query4",
+ """
+ |DROP TABLE IF EXISTS url_counts_partial;
+ |CREATE TABLE url_counts_partial AS
+ | SELECT TRANSFORM (line)
+ | USING 'python target/url_count.py' as (sourcePage,
+ | destPage, count) from documents;
+ |DROP TABLE IF EXISTS url_counts_total;
+ |CREATE TABLE url_counts_total AS
+ | SELECT SUM(count) AS totalCount, destpage
+ | FROM url_counts_partial GROUP BY destpage
+ |-- The following queries run, but generate different results in HIVE likely because the UDF is not deterministic
+ |-- given different input splits.
+ |-- SELECT CAST(SUM(count) AS INT) FROM url_counts_partial
+ |-- SELECT COUNT(*) FROM url_counts_partial
+ |-- SELECT * FROM url_counts_partial
+ |-- SELECT * FROM url_counts_total
+ """.stripMargin)
+ }
+} \ No newline at end of file
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala
new file mode 100644
index 0000000000..a12ab23946
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+package sql
+package hive
+package execution
+
+
+import org.scalatest.{FunSuite, BeforeAndAfterAll}
+
+class ConcurrentHiveSuite extends FunSuite with BeforeAndAfterAll {
+ ignore("multiple instances not supported") {
+ test("Multiple Hive Instances") {
+ (1 to 10).map { i =>
+ val ts =
+ new TestHiveContext(new SparkContext("local", s"TestSQLContext$i", new SparkConf()))
+ ts.executeSql("SHOW TABLES").toRdd.collect()
+ ts.executeSql("SELECT * FROM src").toRdd.collect()
+ ts.executeSql("SHOW TABLES").toRdd.collect()
+ }
+ }
+ }
+} \ No newline at end of file
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
new file mode 100644
index 0000000000..8a5b97b7a0
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -0,0 +1,379 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+package execution
+
+import java.io._
+import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen}
+
+import catalyst.plans.logical.{ExplainCommand, NativeCommand}
+import catalyst.plans._
+import catalyst.util._
+
+import org.apache.spark.sql.execution.Sort
+
+/**
+ * Allows the creation of tests that execute the same query against both Hive
+ * and Catalyst, comparing the results.
+ *
+ * The "golden" results from Hive are cached in and retrieved from both the classpath and
+ * [[answerCache]] to speed up testing.
+ *
+ * See the documentation of public vals in this class for information on how test execution can be
+ * configured using system properties.
+ */
+abstract class HiveComparisonTest extends FunSuite with BeforeAndAfterAll with GivenWhenThen with Logging {
+
+ /**
+ * When set, any cache files that result in test failures will be deleted. Used when the test
+ * harness or Hive has been updated, thus requiring new golden answers to be computed for some
+ * tests. Also prevents the classpath from being used when looking for golden answers, as these
+ * are usually stale.
+ */
+ val recomputeCache = System.getProperty("spark.hive.recomputeCache") != null
+
+ protected val shardRegEx = "(\\d+):(\\d+)".r
+ /**
+ * Allows multiple JVMs to be run in parallel, each responsible for a portion of all test cases.
+ * Format `shardId:numShards`. Shard ids should be zero indexed. E.g. -Dspark.hive.shard=0:4.
+ */
+ val shardInfo = Option(System.getProperty("spark.hive.shard")).map {
+ case shardRegEx(id, total) => (id.toInt, total.toInt)
+ }
+
+ protected val targetDir = new File("target")
+
+ /**
+ * When set, this comma separated list defines directories that contain the names of test cases
+ * that should be skipped.
+ *
+ * For example, when `-Dspark.hive.skiptests=passed,hiveFailed` is specified, test cases listed
+ * in [[passedDirectory]] or [[hiveFailedDirectory]] will be skipped.
+ */
+ val skipDirectories =
+ Option(System.getProperty("spark.hive.skiptests"))
+ .toSeq
+ .flatMap(_.split(","))
+ .map(name => new File(targetDir, s"$suiteName.$name"))
+
+ val runOnlyDirectories =
+ Option(System.getProperty("spark.hive.runonlytests"))
+ .toSeq
+ .flatMap(_.split(","))
+ .map(name => new File(targetDir, s"$suiteName.$name"))
+
+ /** The local directory where cached golden answers are stored. */
+ protected val answerCache = new File("src/test/resources/golden")
+ if (!answerCache.exists) {
+ answerCache.mkdir()
+ }
+
+ /** The [[ClassLoader]] that contains test dependencies. Used to look for golden answers. */
+ protected val testClassLoader = this.getClass.getClassLoader
+
+ /** Directory containing a file for each test case that passes. */
+ val passedDirectory = new File(targetDir, s"$suiteName.passed")
+ if (!passedDirectory.exists()) {
+ passedDirectory.mkdir() // Not atomic!
+ }
+
+ /** Directory containing output of tests that fail to execute with Catalyst. */
+ val failedDirectory = new File(targetDir, s"$suiteName.failed")
+ if (!failedDirectory.exists()) {
+ failedDirectory.mkdir() // Not atomic!
+ }
+
+ /** Directory containing output of tests where catalyst produces the wrong answer. */
+ val wrongDirectory = new File(targetDir, s"$suiteName.wrong")
+ if (!wrongDirectory.exists()) {
+ wrongDirectory.mkdir() // Not atomic!
+ }
+
+ /** Directory containing output of tests where we fail to generate golden output with Hive. */
+ val hiveFailedDirectory = new File(targetDir, s"$suiteName.hiveFailed")
+ if (!hiveFailedDirectory.exists()) {
+ hiveFailedDirectory.mkdir() // Not atomic!
+ }
+
+ /** All directories that contain per-query output files */
+ val outputDirectories = Seq(
+ passedDirectory,
+ failedDirectory,
+ wrongDirectory,
+ hiveFailedDirectory)
+
+ protected val cacheDigest = java.security.MessageDigest.getInstance("MD5")
+ protected def getMd5(str: String): String = {
+ val digest = java.security.MessageDigest.getInstance("MD5")
+ digest.update(str.getBytes)
+ new java.math.BigInteger(1, digest.digest).toString(16)
+ }
+
+ protected def prepareAnswer(
+ hiveQuery: TestHive.type#SqlQueryExecution,
+ answer: Seq[String]): Seq[String] = {
+ val orderedAnswer = hiveQuery.logical match {
+ // Clean out non-deterministic time schema info.
+ case _: NativeCommand => answer.filterNot(nonDeterministicLine).filterNot(_ == "")
+ case _: ExplainCommand => answer
+ case _ =>
+ // TODO: Really we only care about the final total ordering here...
+ val isOrdered = hiveQuery.executedPlan.collect {
+ case s @ Sort(_, global, _) if global => s
+ }.nonEmpty
+ // If the query results aren't sorted, then sort them to ensure deterministic answers.
+ if (!isOrdered) answer.sorted else answer
+ }
+ orderedAnswer.map(cleanPaths)
+ }
+
+ // TODO: Instead of filtering we should clean to avoid accidentally ignoring actual results.
+ lazy val nonDeterministicLineIndicators = Seq(
+ "CreateTime",
+ "transient_lastDdlTime",
+ "grantTime",
+ "lastUpdateTime",
+ "last_modified_time",
+ "Owner:",
+ // The following are hive specific schema parameters which we do not need to match exactly.
+ "numFiles",
+ "numRows",
+ "rawDataSize",
+ "totalSize",
+ "totalNumberFiles",
+ "maxFileSize",
+ "minFileSize"
+ )
+ protected def nonDeterministicLine(line: String) =
+ nonDeterministicLineIndicators.map(line contains _).reduceLeft(_||_)
+
+ /**
+ * Removes non-deterministic paths from `str` so cached answers will compare correctly.
+ */
+ protected def cleanPaths(str: String): String = {
+ str.replaceAll("file:\\/.*\\/", "<PATH>")
+ }
+
+ val installHooksCommand = "(?i)SET.*hooks".r
+ def createQueryTest(testCaseName: String, sql: String) {
+ // If test sharding is enabled, skip tests that are not in the correct shard.
+ shardInfo.foreach {
+ case (shardId, numShards) if testCaseName.hashCode % numShards != shardId => return
+ case (shardId, _) => logger.debug(s"Shard $shardId includes test '$testCaseName'")
+ }
+
+ // Skip tests found in directories specified by user.
+ skipDirectories
+ .map(new File(_, testCaseName))
+ .filter(_.exists)
+ .foreach(_ => return)
+
+ // If runonlytests is set, skip this test unless we find a file in one of the specified
+ // directories.
+ val runIndicators =
+ runOnlyDirectories
+ .map(new File(_, testCaseName))
+ .filter(_.exists)
+ if (runOnlyDirectories.nonEmpty && runIndicators.isEmpty) {
+ logger.debug(
+ s"Skipping test '$testCaseName' not found in ${runOnlyDirectories.map(_.getCanonicalPath)}")
+ return
+ }
+
+ test(testCaseName) {
+ logger.debug(s"=== HIVE TEST: $testCaseName ===")
+
+ // Clear old output for this testcase.
+ outputDirectories.map(new File(_, testCaseName)).filter(_.exists()).foreach(_.delete())
+
+ val allQueries = sql.split("(?<=[^\\\\]);").map(_.trim).filterNot(q => q == "").toSeq
+
+ // TODO: DOCUMENT UNSUPPORTED
+ val queryList =
+ allQueries
+ // In hive, setting the hive.outerjoin.supports.filters flag to "false" essentially tells
+ // the system to return the wrong answer. Since we have no intention of mirroring their
+ // previously broken behavior we simply filter out changes to this setting.
+ .filterNot(_ contains "hive.outerjoin.supports.filters")
+
+ if (allQueries != queryList)
+ logger.warn(s"Simplifications made on unsupported operations for test $testCaseName")
+
+ lazy val consoleTestCase = {
+ val quotes = "\"\"\""
+ queryList.zipWithIndex.map {
+ case (query, i) =>
+ s"""
+ |val q$i = $quotes$query$quotes.q
+ |q$i.stringResult()
+ """.stripMargin
+ }.mkString("\n== Console version of this test ==\n", "\n", "\n")
+ }
+
+ try {
+ // MINOR HACK: You must run a query before calling reset the first time.
+ TestHive.sql("SHOW TABLES")
+ TestHive.reset()
+
+ val hiveCacheFiles = queryList.zipWithIndex.map {
+ case (queryString, i) =>
+ val cachedAnswerName = s"$testCaseName-$i-${getMd5(queryString)}"
+ new File(answerCache, cachedAnswerName)
+ }
+
+ val hiveCachedResults = hiveCacheFiles.flatMap { cachedAnswerFile =>
+ logger.debug(s"Looking for cached answer file $cachedAnswerFile.")
+ if (cachedAnswerFile.exists) {
+ Some(fileToString(cachedAnswerFile))
+ } else {
+ logger.debug(s"File $cachedAnswerFile not found")
+ None
+ }
+ }.map {
+ case "" => Nil
+ case "\n" => Seq("")
+ case other => other.split("\n").toSeq
+ }
+
+ val hiveResults: Seq[Seq[String]] =
+ if (hiveCachedResults.size == queryList.size) {
+ logger.info(s"Using answer cache for test: $testCaseName")
+ hiveCachedResults
+ } else {
+
+ val hiveQueries = queryList.map(new TestHive.SqlQueryExecution(_))
+ // Make sure we can at least parse everything before attempting hive execution.
+ hiveQueries.foreach(_.logical)
+ val computedResults = (queryList.zipWithIndex, hiveQueries, hiveCacheFiles).zipped.map {
+ case ((queryString, i), hiveQuery, cachedAnswerFile) =>
+ try {
+ // Hooks often break the harness and don't really affect our test anyway, so don't
+ // even try running them.
+ if (installHooksCommand.findAllMatchIn(queryString).nonEmpty)
+ sys.error("hive exec hooks not supported for tests.")
+
+ logger.warn(s"Running query ${i+1}/${queryList.size} with hive.")
+ // Analyze the query with catalyst to ensure test tables are loaded.
+ val answer = hiveQuery.analyzed match {
+ case _: ExplainCommand => Nil // No need to execute EXPLAIN queries as we don't check the output.
+ case _ => TestHive.runSqlHive(queryString)
+ }
+
+ // We need to add a new line to non-empty answers so we can differentiate Seq()
+ // from Seq("").
+ stringToFile(
+ cachedAnswerFile, answer.mkString("\n") + (if (answer.nonEmpty) "\n" else ""))
+ answer
+ } catch {
+ case e: Exception =>
+ val errorMessage =
+ s"""
+ |Failed to generate golden answer for query:
+ |Error: ${e.getMessage}
+ |${stackTraceToString(e)}
+ |$queryString
+ |$consoleTestCase
+ """.stripMargin
+ stringToFile(
+ new File(hiveFailedDirectory, testCaseName),
+ errorMessage + consoleTestCase)
+ fail(errorMessage)
+ }
+ }.toSeq
+ TestHive.reset()
+
+ computedResults
+ }
+
+ // Run w/ catalyst
+ val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) =>
+ val query = new TestHive.SqlQueryExecution(queryString)
+ try { (query, prepareAnswer(query, query.stringResult())) } catch {
+ case e: Exception =>
+ val errorMessage =
+ s"""
+ |Failed to execute query using catalyst:
+ |Error: ${e.getMessage}
+ |${stackTraceToString(e)}
+ |$query
+ |== HIVE - ${hive.size} row(s) ==
+ |${hive.mkString("\n")}
+ |
+ |$consoleTestCase
+ """.stripMargin
+ stringToFile(new File(failedDirectory, testCaseName), errorMessage + consoleTestCase)
+ fail(errorMessage)
+ }
+ }.toSeq
+
+ (queryList, hiveResults, catalystResults).zipped.foreach {
+ case (query, hive, (hiveQuery, catalyst)) =>
+ // Check that the results match unless it's an EXPLAIN query.
+ val preparedHive = prepareAnswer(hiveQuery, hive)
+
+ if ((!hiveQuery.logical.isInstanceOf[ExplainCommand]) && preparedHive != catalyst) {
+
+ val hivePrintOut = s"== HIVE - ${hive.size} row(s) ==" +: preparedHive
+ val catalystPrintOut = s"== CATALYST - ${catalyst.size} row(s) ==" +: catalyst
+
+ val resultComparison = sideBySide(hivePrintOut, catalystPrintOut).mkString("\n")
+
+ if (recomputeCache) {
+ logger.warn(s"Clearing cache files for failed test $testCaseName")
+ hiveCacheFiles.foreach(_.delete())
+ }
+
+ val errorMessage =
+ s"""
+ |Results do not match for $testCaseName:
+ |$hiveQuery\n${hiveQuery.analyzed.output.map(_.name).mkString("\t")}
+ |$resultComparison
+ """.stripMargin
+
+ stringToFile(new File(wrongDirectory, testCaseName), errorMessage + consoleTestCase)
+ fail(errorMessage)
+ }
+ }
+
+ // Touch passed file.
+ new FileOutputStream(new File(passedDirectory, testCaseName)).close()
+ } catch {
+ case tf: org.scalatest.exceptions.TestFailedException => throw tf
+ case originalException: Exception =>
+ if (System.getProperty("spark.hive.canarytest") != null) {
+ // When we encounter an error we check to see if the environment is still okay by running a simple query.
+ // If this fails then we halt testing since something must have gone seriously wrong.
+ try {
+ new TestHive.SqlQueryExecution("SELECT key FROM src").stringResult()
+ TestHive.runSqlHive("SELECT key FROM src")
+ } catch {
+ case e: Exception =>
+ logger.error(s"FATAL ERROR: Canary query threw $e This implies that the testing environment has likely been corrupted.")
+ // The testing setup traps exits so wait here for a long time so the developer can see when things started
+ // to go wrong.
+ Thread.sleep(1000000)
+ }
+ }
+
+ // If the canary query didn't fail then the environment is still okay, so just throw the original exception.
+ throw originalException
+ }
+ }
+ }
+} \ No newline at end of file
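
(Editorial sketch, not part of the patch: the two pieces of bookkeeping described above, with
illustrative values. Golden answers are cached per query under src/test/resources/golden, keyed
by test name, query index and an MD5 of the query text; sharding assigns tests to JVMs by
hashing the test name.)

  import java.security.MessageDigest
  def md5(s: String): String = {                       // mirrors getMd5 above
    val digest = MessageDigest.getInstance("MD5")
    digest.update(s.getBytes)
    new java.math.BigInteger(1, digest.digest).toString(16)
  }
  val cacheFileName = s"auto_join0-0-${md5("SELECT * FROM src")}"
  // With -Dspark.hive.shard=1:4, this JVM only runs tests whose name hashes to shard 1:
  val runsInThisShard = "auto_join0".hashCode % 4 == 1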
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
new file mode 100644
index 0000000000..d010023f78
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -0,0 +1,708 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+package execution
+
+
+import java.io._
+
+import util._
+
+/**
+ * Runs the test cases that are included in the hive distribution.
+ */
+class HiveCompatibilitySuite extends HiveQueryFileTest {
+ // TODO: bundle in jar files... get from classpath
+ lazy val hiveQueryDir = TestHive.getHiveFile("ql/src/test/queries/clientpositive")
+ def testCases = hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f)
+
+ /** A list of tests deemed out of scope currently and thus completely disregarded. */
+ override def blackList = Seq(
+ // These tests use hooks that are not on the classpath and thus break all subsequent execution.
+ "hook_order",
+ "hook_context",
+ "mapjoin_hook",
+ "multi_sahooks",
+ "overridden_confs",
+ "query_properties",
+ "sample10",
+ "updateAccessTime",
+ "index_compact_binary_search",
+ "bucket_num_reducers",
+ "column_access_stats",
+ "concatenate_inherit_table_location",
+
+ // Setting a default property does not seem to get reset and thus changes the answer for many
+ // subsequent tests.
+ "create_default_prop",
+
+ // User/machine specific test answers, breaks the caching mechanism.
+ "authorization_3",
+ "authorization_5",
+ "keyword_1",
+ "misc_json",
+ "create_like_tbl_props",
+ "load_overwrite",
+ "alter_table_serde2",
+ "alter_table_not_sorted",
+ "alter_skewed_table",
+ "alter_partition_clusterby_sortby",
+ "alter_merge",
+ "alter_concatenate_indexed_table",
+ "protectmode2",
+ "describe_table",
+ "describe_comment_nonascii",
+ "udf5",
+ "udf_java_method",
+
+ // Weird DDL differences result in failures on jenkins.
+ "create_like2",
+ "create_view_translate",
+ "partitions_json",
+
+ // Timezone specific test answers.
+ "udf_unix_timestamp",
+ "udf_to_unix_timestamp",
+
+ // Can't run without local map/reduce.
+ "index_auto_update",
+ "index_auto_self_join",
+ "index_stale.*",
+ "type_cast_1",
+ "index_compression",
+ "index_bitmap_compression",
+ "index_auto_multiple",
+ "index_auto_mult_tables_compact",
+ "index_auto_mult_tables",
+ "index_auto_file_format",
+ "index_auth",
+ "index_auto_empty",
+ "index_auto_partitioned",
+ "index_auto_unused",
+ "index_bitmap_auto_partitioned",
+ "ql_rewrite_gbtoidx",
+ "stats1.*",
+ "stats20",
+ "alter_merge_stats",
+
+ // Hive seems to think 1.0 > NaN = true && 1.0 < NaN = false... which is wrong.
+ // http://stackoverflow.com/a/1573715
+ "ops_comparison",
+
+ // Tests that seem to never complete on hive...
+ "skewjoin",
+ "database",
+
+ // These tests fail and exit the JVM.
+ "auto_join18_multi_distinct",
+ "join18_multi_distinct",
+ "input44",
+ "input42",
+ "input_dfs",
+ "metadata_export_drop",
+ "repair",
+
+ // Uses a serde that isn't on the classpath... breaks other tests.
+ "bucketizedhiveinputformat",
+
+ // Avro tests seem to change the output format permanently, thus breaking the answer cache. Until
+ // we figure out why this is the case, let's just ignore all avro related tests.
+ ".*avro.*",
+
+ // Unique joins are weird and will require a lot of hacks (see comments in hive parser).
+ "uniquejoin",
+
+ // Hive seems to get the wrong answer on some outer joins. MySQL agrees with catalyst.
+ "auto_join29",
+
+ // No support for multi-alias i.e. udf as (e1, e2, e3).
+ "allcolref_in_udf",
+
+ // No support for TestSerDe (not published afaik)
+ "alter1",
+ "input16",
+
+ // No support for unpublished test udfs.
+ "autogen_colalias",
+
+ // Buckets are not supported yet.
+ ".*bucket.*",
+
+ // No window support yet
+ ".*window.*",
+
+ // Fails in hive with authorization errors.
+ "alter_rename_partition_authorization",
+ "authorization.*",
+
+ // Hadoop version specific tests
+ "archive_corrupt",
+
+ // No support for case sensitivity resolution using hive properties atm.
+ "case_sensitivity"
+ )
+
+ /**
+ * The set of tests that are believed to be working in catalyst. Tests not on whiteList or
+ * blacklist are implicitly marked as ignored.
+ */
+ override def whiteList = Seq(
+ "add_part_exist",
+ "add_partition_no_whitelist",
+ "add_partition_with_whitelist",
+ "alias_casted_column",
+ "alter2",
+ "alter4",
+ "alter5",
+ "alter_index",
+ "alter_merge_2",
+ "alter_partition_format_loc",
+ "alter_partition_protect_mode",
+ "alter_partition_with_whitelist",
+ "alter_table_serde",
+ "alter_varchar2",
+ "alter_view_as_select",
+ "ambiguous_col",
+ "auto_join0",
+ "auto_join1",
+ "auto_join10",
+ "auto_join11",
+ "auto_join12",
+ "auto_join13",
+ "auto_join14",
+ "auto_join14_hadoop20",
+ "auto_join15",
+ "auto_join17",
+ "auto_join18",
+ "auto_join19",
+ "auto_join2",
+ "auto_join20",
+ "auto_join21",
+ "auto_join22",
+ "auto_join23",
+ "auto_join24",
+ "auto_join25",
+ "auto_join26",
+ "auto_join27",
+ "auto_join28",
+ "auto_join3",
+ "auto_join30",
+ "auto_join31",
+ "auto_join32",
+ "auto_join4",
+ "auto_join5",
+ "auto_join6",
+ "auto_join7",
+ "auto_join8",
+ "auto_join9",
+ "auto_join_filters",
+ "auto_join_nulls",
+ "auto_join_reordering_values",
+ "auto_sortmerge_join_1",
+ "auto_sortmerge_join_10",
+ "auto_sortmerge_join_11",
+ "auto_sortmerge_join_12",
+ "auto_sortmerge_join_15",
+ "auto_sortmerge_join_2",
+ "auto_sortmerge_join_3",
+ "auto_sortmerge_join_4",
+ "auto_sortmerge_join_5",
+ "auto_sortmerge_join_6",
+ "auto_sortmerge_join_7",
+ "auto_sortmerge_join_8",
+ "auto_sortmerge_join_9",
+ "binary_constant",
+ "binarysortable_1",
+ "combine1",
+ "compute_stats_binary",
+ "compute_stats_boolean",
+ "compute_stats_double",
+ "compute_stats_table",
+ "compute_stats_long",
+ "compute_stats_string",
+ "convert_enum_to_string",
+ "correlationoptimizer11",
+ "correlationoptimizer15",
+ "correlationoptimizer2",
+ "correlationoptimizer3",
+ "correlationoptimizer4",
+ "correlationoptimizer6",
+ "correlationoptimizer7",
+ "correlationoptimizer8",
+ "count",
+ "create_like_view",
+ "create_nested_type",
+ "create_skewed_table1",
+ "create_struct_table",
+ "ct_case_insensitive",
+ "database_location",
+ "database_properties",
+ "decimal_join",
+ "default_partition_name",
+ "delimiter",
+ "desc_non_existent_tbl",
+ "describe_comment_indent",
+ "describe_database_json",
+ "describe_pretty",
+ "describe_syntax",
+ "describe_table_json",
+ "diff_part_input_formats",
+ "disable_file_format_check",
+ "drop_function",
+ "drop_index",
+ "drop_partitions_filter",
+ "drop_partitions_filter2",
+ "drop_partitions_filter3",
+ "drop_partitions_ignore_protection",
+ "drop_table",
+ "drop_table2",
+ "drop_view",
+ "escape_clusterby1",
+ "escape_distributeby1",
+ "escape_orderby1",
+ "escape_sortby1",
+ "fetch_aggregation",
+ "filter_join_breaktask",
+ "filter_join_breaktask2",
+ "groupby1",
+ "groupby11",
+ "groupby1_map",
+ "groupby1_map_nomap",
+ "groupby1_map_skew",
+ "groupby1_noskew",
+ "groupby4",
+ "groupby4_map",
+ "groupby4_map_skew",
+ "groupby4_noskew",
+ "groupby5",
+ "groupby5_map",
+ "groupby5_map_skew",
+ "groupby5_noskew",
+ "groupby6",
+ "groupby6_map",
+ "groupby6_map_skew",
+ "groupby6_noskew",
+ "groupby7",
+ "groupby7_map",
+ "groupby7_map_multi_single_reducer",
+ "groupby7_map_skew",
+ "groupby7_noskew",
+ "groupby8_map",
+ "groupby8_map_skew",
+ "groupby8_noskew",
+ "groupby_distinct_samekey",
+ "groupby_multi_single_reducer2",
+ "groupby_mutli_insert_common_distinct",
+ "groupby_neg_float",
+ "groupby_sort_10",
+ "groupby_sort_6",
+ "groupby_sort_8",
+ "groupby_sort_test_1",
+ "implicit_cast1",
+ "innerjoin",
+ "inoutdriver",
+ "input",
+ "input0",
+ "input11",
+ "input11_limit",
+ "input12",
+ "input12_hadoop20",
+ "input19",
+ "input1_limit",
+ "input22",
+ "input23",
+ "input24",
+ "input25",
+ "input26",
+ "input28",
+ "input2_limit",
+ "input40",
+ "input41",
+ "input4_cb_delim",
+ "input6",
+ "input7",
+ "input8",
+ "input9",
+ "input_limit",
+ "input_part0",
+ "input_part1",
+ "input_part10",
+ "input_part10_win",
+ "input_part2",
+ "input_part3",
+ "input_part4",
+ "input_part5",
+ "input_part6",
+ "input_part7",
+ "input_part8",
+ "input_part9",
+ "inputddl4",
+ "inputddl7",
+ "inputddl8",
+ "insert_compressed",
+ "join0",
+ "join1",
+ "join10",
+ "join11",
+ "join12",
+ "join13",
+ "join14",
+ "join14_hadoop20",
+ "join15",
+ "join16",
+ "join17",
+ "join18",
+ "join19",
+ "join2",
+ "join20",
+ "join21",
+ "join22",
+ "join23",
+ "join24",
+ "join25",
+ "join26",
+ "join27",
+ "join28",
+ "join29",
+ "join3",
+ "join30",
+ "join31",
+ "join32",
+ "join33",
+ "join34",
+ "join35",
+ "join36",
+ "join37",
+ "join38",
+ "join39",
+ "join4",
+ "join40",
+ "join41",
+ "join5",
+ "join6",
+ "join7",
+ "join8",
+ "join9",
+ "join_1to1",
+ "join_array",
+ "join_casesensitive",
+ "join_empty",
+ "join_filters",
+ "join_hive_626",
+ "join_nulls",
+ "join_reorder2",
+ "join_reorder3",
+ "join_reorder4",
+ "join_star",
+ "join_view",
+ "lateral_view_cp",
+ "lateral_view_ppd",
+ "lineage1",
+ "literal_double",
+ "literal_ints",
+ "literal_string",
+ "load_dyn_part7",
+ "load_file_with_space_in_the_name",
+ "louter_join_ppr",
+ "mapjoin_distinct",
+ "mapjoin_mapjoin",
+ "mapjoin_subquery",
+ "mapjoin_subquery2",
+ "mapjoin_test_outer",
+ "mapreduce3",
+ "mapreduce7",
+ "merge1",
+ "merge2",
+ "mergejoins",
+ "mergejoins_mixed",
+ "multiMapJoin1",
+ "multiMapJoin2",
+ "multi_join_union",
+ "multigroupby_singlemr",
+ "noalias_subq1",
+ "nomore_ambiguous_table_col",
+ "nonblock_op_deduplicate",
+ "notable_alias1",
+ "notable_alias2",
+ "nullgroup",
+ "nullgroup2",
+ "nullgroup3",
+ "nullgroup4",
+ "nullgroup4_multi_distinct",
+ "nullgroup5",
+ "nullinput",
+ "nullinput2",
+ "nullscript",
+ "optional_outer",
+ "order",
+ "order2",
+ "outer_join_ppr",
+ "part_inherit_tbl_props",
+ "part_inherit_tbl_props_empty",
+ "part_inherit_tbl_props_with_star",
+ "partition_schema1",
+ "partition_varchar1",
+ "plan_json",
+ "ppd1",
+ "ppd_constant_where",
+ "ppd_gby",
+ "ppd_gby2",
+ "ppd_gby_join",
+ "ppd_join",
+ "ppd_join2",
+ "ppd_join3",
+ "ppd_join_filter",
+ "ppd_outer_join1",
+ "ppd_outer_join2",
+ "ppd_outer_join3",
+ "ppd_outer_join4",
+ "ppd_outer_join5",
+ "ppd_random",
+ "ppd_repeated_alias",
+ "ppd_udf_col",
+ "ppd_union",
+ "ppr_allchildsarenull",
+ "ppr_pushdown",
+ "ppr_pushdown2",
+ "ppr_pushdown3",
+ "progress_1",
+ "protectmode",
+ "push_or",
+ "query_with_semi",
+ "quote1",
+ "quote2",
+ "reduce_deduplicate_exclude_join",
+ "rename_column",
+ "router_join_ppr",
+ "select_as_omitted",
+ "select_unquote_and",
+ "select_unquote_not",
+ "select_unquote_or",
+ "serde_reported_schema",
+ "set_variable_sub",
+ "show_describe_func_quotes",
+ "show_functions",
+ "show_partitions",
+ "skewjoinopt13",
+ "skewjoinopt18",
+ "skewjoinopt9",
+ "smb_mapjoin_1",
+ "smb_mapjoin_10",
+ "smb_mapjoin_13",
+ "smb_mapjoin_14",
+ "smb_mapjoin_15",
+ "smb_mapjoin_16",
+ "smb_mapjoin_17",
+ "smb_mapjoin_2",
+ "smb_mapjoin_21",
+ "smb_mapjoin_25",
+ "smb_mapjoin_3",
+ "smb_mapjoin_4",
+ "smb_mapjoin_5",
+ "smb_mapjoin_8",
+ "sort",
+ "sort_merge_join_desc_1",
+ "sort_merge_join_desc_2",
+ "sort_merge_join_desc_3",
+ "sort_merge_join_desc_4",
+ "sort_merge_join_desc_5",
+ "sort_merge_join_desc_6",
+ "sort_merge_join_desc_7",
+ "stats0",
+ "stats_empty_partition",
+ "subq2",
+ "tablename_with_select",
+ "touch",
+ "type_widening",
+ "udaf_collect_set",
+ "udaf_corr",
+ "udaf_covar_pop",
+ "udaf_covar_samp",
+ "udf2",
+ "udf6",
+ "udf9",
+ "udf_10_trims",
+ "udf_E",
+ "udf_PI",
+ "udf_abs",
+ "udf_acos",
+ "udf_add",
+ "udf_array",
+ "udf_array_contains",
+ "udf_ascii",
+ "udf_asin",
+ "udf_atan",
+ "udf_avg",
+ "udf_bigint",
+ "udf_bin",
+ "udf_bitmap_and",
+ "udf_bitmap_empty",
+ "udf_bitmap_or",
+ "udf_bitwise_and",
+ "udf_bitwise_not",
+ "udf_bitwise_or",
+ "udf_bitwise_xor",
+ "udf_boolean",
+ "udf_case",
+ "udf_ceil",
+ "udf_ceiling",
+ "udf_concat",
+ "udf_concat_insert2",
+ "udf_concat_ws",
+ "udf_conv",
+ "udf_cos",
+ "udf_count",
+ "udf_date_add",
+ "udf_date_sub",
+ "udf_datediff",
+ "udf_day",
+ "udf_dayofmonth",
+ "udf_degrees",
+ "udf_div",
+ "udf_double",
+ "udf_exp",
+ "udf_field",
+ "udf_find_in_set",
+ "udf_float",
+ "udf_floor",
+ "udf_format_number",
+ "udf_from_unixtime",
+ "udf_greaterthan",
+ "udf_greaterthanorequal",
+ "udf_hex",
+ "udf_if",
+ "udf_index",
+ "udf_int",
+ "udf_isnotnull",
+ "udf_isnull",
+ "udf_java_method",
+ "udf_lcase",
+ "udf_length",
+ "udf_lessthan",
+ "udf_lessthanorequal",
+ "udf_like",
+ "udf_ln",
+ "udf_log",
+ "udf_log10",
+ "udf_log2",
+ "udf_lower",
+ "udf_lpad",
+ "udf_ltrim",
+ "udf_map",
+ "udf_minute",
+ "udf_modulo",
+ "udf_month",
+ "udf_negative",
+ "udf_not",
+ "udf_notequal",
+ "udf_notop",
+ "udf_nvl",
+ "udf_or",
+ "udf_parse_url",
+ "udf_positive",
+ "udf_pow",
+ "udf_power",
+ "udf_radians",
+ "udf_rand",
+ "udf_regexp",
+ "udf_regexp_extract",
+ "udf_regexp_replace",
+ "udf_repeat",
+ "udf_rlike",
+ "udf_round",
+ "udf_round_3",
+ "udf_rpad",
+ "udf_rtrim",
+ "udf_second",
+ "udf_sign",
+ "udf_sin",
+ "udf_smallint",
+ "udf_space",
+ "udf_sqrt",
+ "udf_std",
+ "udf_stddev",
+ "udf_stddev_pop",
+ "udf_stddev_samp",
+ "udf_string",
+ "udf_substring",
+ "udf_subtract",
+ "udf_sum",
+ "udf_tan",
+ "udf_tinyint",
+ "udf_to_byte",
+ "udf_to_date",
+ "udf_to_double",
+ "udf_to_float",
+ "udf_to_long",
+ "udf_to_short",
+ "udf_translate",
+ "udf_trim",
+ "udf_ucase",
+ "udf_upper",
+ "udf_var_pop",
+ "udf_var_samp",
+ "udf_variance",
+ "udf_weekofyear",
+ "udf_when",
+ "udf_xpath",
+ "udf_xpath_boolean",
+ "udf_xpath_double",
+ "udf_xpath_float",
+ "udf_xpath_int",
+ "udf_xpath_long",
+ "udf_xpath_short",
+ "udf_xpath_string",
+ "unicode_notation",
+ "union10",
+ "union11",
+ "union13",
+ "union14",
+ "union15",
+ "union16",
+ "union17",
+ "union18",
+ "union19",
+ "union2",
+ "union20",
+ "union22",
+ "union23",
+ "union24",
+ "union26",
+ "union27",
+ "union28",
+ "union29",
+ "union30",
+ "union31",
+ "union34",
+ "union4",
+ "union5",
+ "union6",
+ "union7",
+ "union8",
+ "union9",
+ "union_lateralview",
+ "union_ppr",
+ "union_remove_3",
+ "union_remove_6",
+ "union_script",
+ "varchar_2",
+ "varchar_join1",
+ "varchar_union1"
+ )
+}
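
(Editorial sketch, not part of the patch: the suite above simply turns every .q file under the
Hive source tree into a named test case; the path below is the one TestHive.getHiveFile is
expected to resolve.)

  import java.io.File
  val hiveQueryDir = new File("ql/src/test/queries/clientpositive")
  val testCases: Seq[(String, File)] =
    hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f).toSeq
  // e.g. ("auto_join0", .../auto_join0.q); each name is then checked against blackList/whiteList.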
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala
new file mode 100644
index 0000000000..f0a4ec3c02
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+package execution
+
+import java.io._
+
+import catalyst.util._
+
+/**
+ * A framework for running the query tests that are listed as a set of text files.
+ *
+ * TestSuites that derive from this class must provide a map of testCaseName -> testCaseFiles that should be included.
+ * Additionally, there is support for whitelisting and blacklisting tests as development progresses.
+ */
+abstract class HiveQueryFileTest extends HiveComparisonTest {
+ /** A list of tests deemed out of scope and thus completely disregarded */
+ def blackList: Seq[String] = Nil
+
+ /**
+ * The set of tests that are believed to be working in catalyst. Tests not in the whiteList or
+ * blackList are implicitly marked as ignored.
+ */
+ def whiteList: Seq[String] = ".*" :: Nil
+
+ def testCases: Seq[(String, File)]
+
+ val runAll =
+ !(System.getProperty("spark.hive.alltests") == null) ||
+ runOnlyDirectories.nonEmpty ||
+ skipDirectories.nonEmpty
+
+ val whiteListProperty = "spark.hive.whitelist"
+ // Allow the whiteList to be overridden by a system property
+ val realWhiteList =
+ Option(System.getProperty(whiteListProperty)).map(_.split(",").toSeq).getOrElse(whiteList)
+
+ // Go through all the test cases and add them to scala test.
+ testCases.sorted.foreach {
+ case (testCaseName, testCaseFile) =>
+ if (blackList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_)) {
+ logger.debug(s"Blacklisted test skipped $testCaseName")
+ } else if (realWhiteList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_) || runAll) {
+ // Build a test case and submit it to scala test framework...
+ val queriesString = fileToString(testCaseFile)
+ createQueryTest(testCaseName, queriesString)
+ } else {
+ // Only output warnings for the built-in whitelist, as this clutters the output when the user
+ // is trying to execute a single test from the commandline.
+ if (System.getProperty(whiteListProperty) == null && !runAll)
+ ignore(testCaseName) {}
+ }
+ }
+} \ No newline at end of file
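
(Editorial sketch, not part of the patch: how the white/black lists above are applied. Entries
are treated as full-match regular expressions, and -Dspark.hive.whitelist replaces the built-in
whiteList entirely.)

  val whiteList = Seq("auto_join.*", "input1")
  def matches(patterns: Seq[String], testCaseName: String): Boolean =
    patterns.exists(_.r.pattern.matcher(testCaseName).matches())
  matches(whiteList, "auto_join13")   // true  -> a test is created for this case
  matches(whiteList, "input12")       // false -> the case is ignored (unless spark.hive.alltests is set)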
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
new file mode 100644
index 0000000000..28a5d260b3
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+package execution
+
+
+/**
+ * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution.
+ */
+class HiveQuerySuite extends HiveComparisonTest {
+ import TestHive._
+
+ createQueryTest("Simple Average",
+ "SELECT AVG(key) FROM src")
+
+ createQueryTest("Simple Average + 1",
+ "SELECT AVG(key) + 1.0 FROM src")
+
+ createQueryTest("Simple Average + 1 with group",
+ "SELECT AVG(key) + 1.0, value FROM src group by value")
+
+ createQueryTest("string literal",
+ "SELECT 'test' FROM src")
+
+ createQueryTest("Escape sequences",
+ """SELECT key, '\\\t\\' FROM src WHERE key = 86""")
+
+ createQueryTest("IgnoreExplain",
+ """EXPLAIN SELECT key FROM src""")
+
+ createQueryTest("trivial join where clause",
+ "SELECT * FROM src a JOIN src b WHERE a.key = b.key")
+
+ createQueryTest("trivial join ON clause",
+ "SELECT * FROM src a JOIN src b ON a.key = b.key")
+
+ createQueryTest("small.cartesian",
+ "SELECT a.key, b.key FROM (SELECT key FROM src WHERE key < 1) a JOIN (SELECT key FROM src WHERE key = 2) b")
+
+ createQueryTest("length.udf",
+ "SELECT length(\"test\") FROM src LIMIT 1")
+
+ ignore("partitioned table scan") {
+ createQueryTest("partitioned table scan",
+ "SELECT ds, hr, key, value FROM srcpart")
+ }
+
+ createQueryTest("hash",
+ "SELECT hash('test') FROM src LIMIT 1")
+
+ createQueryTest("create table as",
+ """
+ |CREATE TABLE createdtable AS SELECT * FROM src;
+ |SELECT * FROM createdtable
+ """.stripMargin)
+
+ createQueryTest("create table as with db name",
+ """
+ |CREATE DATABASE IF NOT EXISTS testdb;
+ |CREATE TABLE testdb.createdtable AS SELECT * FROM default.src;
+ |SELECT * FROM testdb.createdtable;
+ |DROP DATABASE IF EXISTS testdb CASCADE
+ """.stripMargin)
+
+ createQueryTest("insert table with db name",
+ """
+ |CREATE DATABASE IF NOT EXISTS testdb;
+ |CREATE TABLE testdb.createdtable like default.src;
+ |INSERT INTO TABLE testdb.createdtable SELECT * FROM default.src;
+ |SELECT * FROM testdb.createdtable;
+ |DROP DATABASE IF EXISTS testdb CASCADE
+ """.stripMargin)
+
+ createQueryTest("insert into and insert overwrite",
+ """
+ |CREATE TABLE createdtable like src;
+ |INSERT INTO TABLE createdtable SELECT * FROM src;
+ |INSERT INTO TABLE createdtable SELECT * FROM src1;
+ |SELECT * FROM createdtable;
+ |INSERT OVERWRITE TABLE createdtable SELECT * FROM src WHERE key = 86;
+ |SELECT * FROM createdtable;
+ """.stripMargin)
+
+ createQueryTest("transform",
+ "SELECT TRANSFORM (key) USING 'cat' AS (tKey) FROM src")
+
+ createQueryTest("LIKE",
+ "SELECT * FROM src WHERE value LIKE '%1%'")
+
+ createQueryTest("DISTINCT",
+ "SELECT DISTINCT key, value FROM src")
+
+ ignore("empty aggregate input") {
+ createQueryTest("empty aggregate input",
+ "SELECT SUM(key) FROM (SELECT * FROM src LIMIT 0) a")
+ }
+
+ createQueryTest("lateral view1",
+ "SELECT tbl.* FROM src LATERAL VIEW explode(array(1,2)) tbl as a")
+
+ createQueryTest("lateral view2",
+ "SELECT * FROM src LATERAL VIEW explode(array(1,2)) tbl")
+
+ createQueryTest("lateral view3",
+ "FROM src SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX")
+
+ createQueryTest("lateral view4",
+ """
+ |create table src_lv1 (key string, value string);
+ |create table src_lv2 (key string, value string);
+ |
+ |FROM src
+ |insert overwrite table src_lv1 SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX
+ |insert overwrite table src_lv2 SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX
+ """.stripMargin)
+
+ createQueryTest("lateral view5",
+ "FROM src SELECT explode(array(key+3, key+4))")
+
+ createQueryTest("lateral view6",
+ "SELECT * FROM src LATERAL VIEW explode(map(key+3,key+4)) D as k, v")
+
+ test("sampling") {
+ sql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s")
+ }
+
+}
\ No newline at end of file
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
new file mode 100644
index 0000000000..0dd79faa15
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+package execution
+
+/**
+ * A set of test cases expressed in Hive QL that are not covered by the tests included in the Hive distribution.
+ */
+class HiveResolutionSuite extends HiveComparisonTest {
+ import TestHive._
+
+ createQueryTest("table.attr",
+ "SELECT src.key FROM src ORDER BY key LIMIT 1")
+
+ createQueryTest("database.table",
+ "SELECT key FROM default.src ORDER BY key LIMIT 1")
+
+ createQueryTest("database.table table.attr",
+ "SELECT src.key FROM default.src ORDER BY key LIMIT 1")
+
+ createQueryTest("alias.attr",
+ "SELECT a.key FROM src a ORDER BY key LIMIT 1")
+
+ createQueryTest("subquery-alias.attr",
+ "SELECT a.key FROM (SELECT * FROM src ORDER BY key LIMIT 1) a")
+
+ createQueryTest("quoted alias.attr",
+ "SELECT `a`.`key` FROM src a ORDER BY key LIMIT 1")
+
+ createQueryTest("attr",
+ "SELECT key FROM src a ORDER BY key LIMIT 1")
+
+ createQueryTest("alias.*",
+ "SELECT a.* FROM src a ORDER BY key LIMIT 1")
+
+ /**
+ * Negative examples. Currently only left here for documentation purposes.
+ * TODO(marmbrus): Test that catalyst fails on these queries.
+ */
+
+ /* SemanticException [Error 10009]: Line 1:7 Invalid table alias 'src'
+ createQueryTest("table.*",
+ "SELECT src.* FROM src a ORDER BY key LIMIT 1") */
+
+ /* Invalid table alias or column reference 'src': (possible column names are: key, value)
+ createQueryTest("tableName.attr from aliased subquery",
+ "SELECT src.key FROM (SELECT * FROM src ORDER BY key LIMIT 1) a") */
+
+}
\ No newline at end of file
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
new file mode 100644
index 0000000000..c2264926f4
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+package execution
+
+/**
+ * A set of tests that validate support for Hive SerDes.
+ */
+class HiveSerDeSuite extends HiveComparisonTest {
+ createQueryTest(
+ "Read and write with LazySimpleSerDe (tab separated)",
+ "SELECT * from serdeins")
+
+ createQueryTest("Read with RegexSerDe", "SELECT * FROM sales")
+
+ createQueryTest("Read with AvroSerDe", "SELECT * FROM episodes")
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala
new file mode 100644
index 0000000000..bb33583e5f
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+/**
+ * A set of tests that validate type promotion rules.
+ */
+class HiveTypeCoercionSuite extends HiveComparisonTest {
+
+ val baseTypes = Seq("1", "1.0", "1L", "1S", "1Y", "'1'")
+
+ baseTypes.foreach { i =>
+ baseTypes.foreach { j =>
+ createQueryTest(s"$i + $j", s"SELECT $i + $j FROM src LIMIT 1")
+ }
+ }
+}
\ No newline at end of file
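An aside, not part of the patch: the six entries in baseTypes cover Hive's literal syntax (1 is an INT, 1.0 a DOUBLE, 1L a BIGINT, 1S a SMALLINT, 1Y a TINYINT, and '1' a STRING), so the two nested loops register 36 tests, one per ordered pair. For example, the iteration with i = "1Y" and j = "'1'" expands to:

    createQueryTest("1Y + '1'", "SELECT 1Y + '1' FROM src LIMIT 1")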
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
new file mode 100644
index 0000000000..8542f42aa9
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+package execution
+
+import scala.collection.JavaConversions._
+
+import org.apache.spark.sql.hive.TestHive
+
+/**
+ * A set of test cases that validate partition and column pruning.
+ */
+class PruningSuite extends HiveComparisonTest {
+ // Column pruning tests
+
+ createPruningTest("Column pruning: with partitioned table",
+ "SELECT key FROM srcpart WHERE ds = '2008-04-08' LIMIT 3",
+ Seq("key"),
+ Seq("key", "ds"),
+ Seq(
+ Seq("2008-04-08", "11"),
+ Seq("2008-04-08", "12")))
+
+ createPruningTest("Column pruning: with non-partitioned table",
+ "SELECT key FROM src WHERE key > 10 LIMIT 3",
+ Seq("key"),
+ Seq("key"),
+ Seq.empty)
+
+ createPruningTest("Column pruning: with multiple projects",
+ "SELECT c1 FROM (SELECT key AS c1 FROM src WHERE key > 10) t1 LIMIT 3",
+ Seq("c1"),
+ Seq("key"),
+ Seq.empty)
+
+ createPruningTest("Column pruning: projects alias substituting",
+ "SELECT c1 AS c2 FROM (SELECT key AS c1 FROM src WHERE key > 10) t1 LIMIT 3",
+ Seq("c2"),
+ Seq("key"),
+ Seq.empty)
+
+ createPruningTest("Column pruning: filter alias in-lining",
+ "SELECT c1 FROM (SELECT key AS c1 FROM src WHERE key > 10) t1 WHERE c1 < 100 LIMIT 3",
+ Seq("c1"),
+ Seq("key"),
+ Seq.empty)
+
+ createPruningTest("Column pruning: without filters",
+ "SELECT c1 FROM (SELECT key AS c1 FROM src) t1 LIMIT 3",
+ Seq("c1"),
+ Seq("key"),
+ Seq.empty)
+
+ createPruningTest("Column pruning: simple top project without aliases",
+ "SELECT key FROM (SELECT key FROM src WHERE key > 10) t1 WHERE key < 100 LIMIT 3",
+ Seq("key"),
+ Seq("key"),
+ Seq.empty)
+
+ createPruningTest("Column pruning: non-trivial top project with aliases",
+ "SELECT c1 * 2 AS double FROM (SELECT key AS c1 FROM src WHERE key > 10) t1 LIMIT 3",
+ Seq("double"),
+ Seq("key"),
+ Seq.empty)
+
+ // Partition pruning tests
+
+ createPruningTest("Partition pruning: non-partitioned, non-trivial project",
+ "SELECT key * 2 AS double FROM src WHERE value IS NOT NULL",
+ Seq("double"),
+ Seq("key", "value"),
+ Seq.empty)
+
+  createPruningTest("Partition pruning: non-partitioned table",
+ "SELECT value FROM src WHERE key IS NOT NULL",
+ Seq("value"),
+ Seq("value", "key"),
+ Seq.empty)
+
+ createPruningTest("Partition pruning: with filter on string partition key",
+ "SELECT value, hr FROM srcpart1 WHERE ds = '2008-04-08'",
+ Seq("value", "hr"),
+ Seq("value", "hr", "ds"),
+ Seq(
+ Seq("2008-04-08", "11"),
+ Seq("2008-04-08", "12")))
+
+ createPruningTest("Partition pruning: with filter on int partition key",
+ "SELECT value, hr FROM srcpart1 WHERE hr < 12",
+ Seq("value", "hr"),
+ Seq("value", "hr"),
+ Seq(
+ Seq("2008-04-08", "11"),
+ Seq("2008-04-09", "11")))
+
+ createPruningTest("Partition pruning: left only 1 partition",
+ "SELECT value, hr FROM srcpart1 WHERE ds = '2008-04-08' AND hr < 12",
+ Seq("value", "hr"),
+ Seq("value", "hr", "ds"),
+ Seq(
+ Seq("2008-04-08", "11")))
+
+ createPruningTest("Partition pruning: all partitions pruned",
+ "SELECT value, hr FROM srcpart1 WHERE ds = '2014-01-27' AND hr = 11",
+ Seq("value", "hr"),
+ Seq("value", "hr", "ds"),
+ Seq.empty)
+
+ createPruningTest("Partition pruning: pruning with both column key and partition key",
+ "SELECT value, hr FROM srcpart1 WHERE value IS NOT NULL AND hr < 12",
+ Seq("value", "hr"),
+ Seq("value", "hr"),
+ Seq(
+ Seq("2008-04-08", "11"),
+ Seq("2008-04-09", "11")))
+
+ def createPruningTest(
+ testCaseName: String,
+ sql: String,
+ expectedOutputColumns: Seq[String],
+ expectedScannedColumns: Seq[String],
+ expectedPartValues: Seq[Seq[String]]) = {
+ test(s"$testCaseName - pruning test") {
+ val plan = new TestHive.SqlQueryExecution(sql).executedPlan
+ val actualOutputColumns = plan.output.map(_.name)
+ val (actualScannedColumns, actualPartValues) = plan.collect {
+ case p @ HiveTableScan(columns, relation, _) =>
+ val columnNames = columns.map(_.name)
+ val partValues = p.prunePartitions(relation.hiveQlPartitions).map(_.getValues)
+ (columnNames, partValues)
+ }.head
+
+ assert(actualOutputColumns sameElements expectedOutputColumns, "Output columns mismatch")
+ assert(actualScannedColumns sameElements expectedScannedColumns, "Scanned columns mismatch")
+
+ assert(
+ actualPartValues.length === expectedPartValues.length,
+ "Partition value count mismatches")
+
+ for ((actual, expected) <- actualPartValues.zip(expectedPartValues)) {
+ assert(actual sameElements expected, "Partition values mismatch")
+ }
+ }
+
+ // Creates a query test to compare query results generated by Hive and Catalyst.
+ createQueryTest(s"$testCaseName - query test", sql)
+ }
+}
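As a reading aid, a sketch (not part of the patch) of how a further case could be written with createPruningTest, using named arguments to spell out what the positional Seqs above mean; the query and the expected values here are illustrative guesses, not verified output:

    createPruningTest(
      testCaseName = "Partition pruning: hypothetical extra case",
      sql = "SELECT value FROM srcpart1 WHERE ds = '2008-04-09'",
      expectedOutputColumns = Seq("value"),          // columns the executed plan outputs
      expectedScannedColumns = Seq("value", "ds"),   // columns read by the HiveTableScan
      expectedPartValues = Seq(                      // (ds, hr) values of partitions surviving pruning
        Seq("2008-04-09", "11"),
        Seq("2008-04-09", "12")))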
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala
new file mode 100644
index 0000000000..ee90061c7c
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.parquet
+
+import java.io.File
+
+import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.sql.catalyst.expressions.Row
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.util.getTempFilePath
+import org.apache.spark.sql.hive.TestHive
+
+
+class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach {
+
+ val filename = getTempFilePath("parquettest").getCanonicalFile.toURI.toString
+
+  // Runs a SQL query and optionally resolves one Parquet table.
+ def runQuery(querystr: String, tableName: Option[String] = None, filename: Option[String] = None): Array[Row] = {
+ // call to resolve references in order to get CREATE TABLE AS to work
+ val query = TestHive
+ .parseSql(querystr)
+ val finalQuery =
+ if (tableName.nonEmpty && filename.nonEmpty)
+ resolveParquetTable(tableName.get, filename.get, query)
+ else
+ query
+ TestHive.executePlan(finalQuery)
+ .toRdd
+ .collect()
+ }
+
+  // Stores the output of a query in a Parquet file.
+ def storeQuery(querystr: String, filename: String): Unit = {
+ val query = WriteToFile(
+ filename,
+ TestHive.parseSql(querystr))
+ TestHive
+ .executePlan(query)
+ .stringResult()
+ }
+
+ /**
+ * TODO: This function is necessary as long as there is no notion of a Catalog for
+   * Parquet tables. Once such a catalog exists, this functionality should be moved there.
+ */
+ def resolveParquetTable(tableName: String, filename: String, plan: LogicalPlan): LogicalPlan = {
+    TestHive.loadTestTable("src") // the "src" table may not have been loaded yet
+ plan.transform {
+ case relation @ UnresolvedRelation(databaseName, name, alias) =>
+ if (name == tableName)
+ ParquetRelation(tableName, filename)
+ else
+ relation
+ case op @ InsertIntoCreatedTable(databaseName, name, child) =>
+ if (name == tableName) {
+          // Note: at this stage the plan is not yet analyzed, but Parquet needs to know the
+          // schema, and for that the child has to be resolved.
+ val relation = ParquetRelation.create(
+ filename,
+ TestHive.analyzer(child),
+ TestHive.sparkContext.hadoopConfiguration,
+ Some(tableName))
+ InsertIntoTable(
+ relation.asInstanceOf[BaseRelation],
+ Map.empty,
+ child,
+ overwrite = false)
+ } else
+ op
+ }
+ }
+
+ override def beforeAll() {
+ // write test data
+ ParquetTestData.writeFile
+ // Override initial Parquet test table
+ TestHive.catalog.registerTable(Some[String]("parquet"), "testsource", ParquetTestData.testData)
+ }
+
+ override def afterAll() {
+ ParquetTestData.testFile.delete()
+ }
+
+ override def beforeEach() {
+ new File(filename).getAbsoluteFile.delete()
+ }
+
+ override def afterEach() {
+ new File(filename).getAbsoluteFile.delete()
+ }
+
+ test("SELECT on Parquet table") {
+ val rdd = runQuery("SELECT * FROM parquet.testsource")
+ assert(rdd != null)
+ assert(rdd.forall(_.size == 6))
+ }
+
+ test("Simple column projection + filter on Parquet table") {
+ val rdd = runQuery("SELECT myboolean, mylong FROM parquet.testsource WHERE myboolean=true")
+ assert(rdd.size === 5, "Filter returned incorrect number of rows")
+ assert(rdd.forall(_.getBoolean(0)), "Filter returned incorrect Boolean field value")
+ }
+
+ test("Converting Hive to Parquet Table via WriteToFile") {
+ storeQuery("SELECT * FROM src", filename)
+ val rddOne = runQuery("SELECT * FROM src").sortBy(_.getInt(0))
+ val rddTwo = runQuery("SELECT * from ptable", Some("ptable"), Some(filename)).sortBy(_.getInt(0))
+ compareRDDs(rddOne, rddTwo, "src (Hive)", Seq("key:Int", "value:String"))
+ }
+
+ test("INSERT OVERWRITE TABLE Parquet table") {
+ storeQuery("SELECT * FROM parquet.testsource", filename)
+ runQuery("INSERT OVERWRITE TABLE ptable SELECT * FROM parquet.testsource", Some("ptable"), Some(filename))
+ runQuery("INSERT OVERWRITE TABLE ptable SELECT * FROM parquet.testsource", Some("ptable"), Some(filename))
+ val rddCopy = runQuery("SELECT * FROM ptable", Some("ptable"), Some(filename))
+ val rddOrig = runQuery("SELECT * FROM parquet.testsource")
+ compareRDDs(rddOrig, rddCopy, "parquet.testsource", ParquetTestData.testSchemaFieldNames)
+ }
+
+ test("CREATE TABLE AS Parquet table") {
+ runQuery("CREATE TABLE ptable AS SELECT * FROM src", Some("ptable"), Some(filename))
+ val rddCopy = runQuery("SELECT * FROM ptable", Some("ptable"), Some(filename))
+ .sortBy[Int](_.apply(0) match {
+ case x: Int => x
+ case _ => 0
+ })
+ val rddOrig = runQuery("SELECT * FROM src").sortBy(_.getInt(0))
+ compareRDDs(rddOrig, rddCopy, "src (Hive)", Seq("key:Int", "value:String"))
+ }
+
+ private def compareRDDs(rddOne: Array[Row], rddTwo: Array[Row], tableName: String, fieldNames: Seq[String]) {
+ var counter = 0
+ (rddOne, rddTwo).zipped.foreach {
+ (a,b) => (a,b).zipped.toArray.zipWithIndex.foreach {
+ case ((value_1:Array[Byte], value_2:Array[Byte]), index) =>
+          assert(new String(value_1) === new String(value_2), s"table $tableName row $counter field ${fieldNames(index)} doesn't match")
+ case ((value_1, value_2), index) =>
+          assert(value_1 === value_2, s"table $tableName row $counter field ${fieldNames(index)} doesn't match")
+ }
+ counter = counter + 1
+ }
+ }
+}
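A closing note on the helpers above (not part of the patch): because there is no Parquet catalog yet, any query touching the ad-hoc Parquet table has to go through runQuery with both the table name and the file so that resolveParquetTable can rewrite the UnresolvedRelation and InsertIntoCreatedTable nodes, e.g.:

    // Resolves "ptable" against the Parquet file before execution (names taken from the tests above).
    val rows = runQuery("SELECT * FROM ptable", Some("ptable"), Some(filename))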