Diffstat (limited to 'sql/hive')
23 files changed, 5392 insertions, 0 deletions
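The diff below introduces Spark SQL's Hive integration: a HiveContext that parses HiveQL, a metastore-backed catalog, a Hive-aware output writer, and physical planning strategies. As a hedged orientation sketch (not part of the diff; it assumes the sql(...) entry point inherited from the sibling sql/core SQLContext in the same patch set), typical usage from a Spark shell would look roughly like:

    import org.apache.spark.SparkContext
    import org.apache.spark.sql.hive.LocalHiveContext

    val sc = new SparkContext("local", "HiveExample")
    // LocalHiveContext (defined below) starts an in-process Derby metastore in
    // ./metastore and writes warehouse data to ./warehouse.
    val hiveContext = new LocalHiveContext(sc)
    // parseSql is overridden to delegate to HiveQl, so query strings are HiveQL.
    val rows = hiveContext.sql("SELECT key, value FROM src").collect()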
diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml new file mode 100644 index 0000000000..7b5ea98f27 --- /dev/null +++ b/sql/hive/pom.xml @@ -0,0 +1,81 @@ +<?xml version="1.0" encoding="UTF-8"?> +<!-- + ~ Licensed to the Apache Software Foundation (ASF) under one or more + ~ contributor license agreements. See the NOTICE file distributed with + ~ this work for additional information regarding copyright ownership. + ~ The ASF licenses this file to You under the Apache License, Version 2.0 + ~ (the "License"); you may not use this file except in compliance with + ~ the License. You may obtain a copy of the License at + ~ + ~ http://www.apache.org/licenses/LICENSE-2.0 + ~ + ~ Unless required by applicable law or agreed to in writing, software + ~ distributed under the License is distributed on an "AS IS" BASIS, + ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + ~ See the License for the specific language governing permissions and + ~ limitations under the License. + --> + +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <modelVersion>4.0.0</modelVersion> + <parent> + <groupId>org.apache.spark</groupId> + <artifactId>spark-parent</artifactId> + <version>1.0.0-SNAPSHOT</version> + <relativePath>../../pom.xml</relativePath> + </parent> + + <groupId>org.apache.spark</groupId> + <artifactId>spark-hive_2.10</artifactId> + <packaging>jar</packaging> + <name>Spark Project Hive</name> + <url>http://spark.apache.org/</url> + + <dependencies> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-core_${scala.binary.version}</artifactId> + <version>${project.version}</version> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-sql_${scala.binary.version}</artifactId> + <version>${project.version}</version> + </dependency> + <dependency> + <groupId>org.apache.hive</groupId> + <artifactId>hive-metastore</artifactId> + <version>${hive.version}</version> + </dependency> + <dependency> + <groupId>org.apache.hive</groupId> + <artifactId>hive-exec</artifactId> + <version>${hive.version}</version> + </dependency> + <dependency> + <groupId>org.apache.hive</groupId> + <artifactId>hive-serde</artifactId> + <version>${hive.version}</version> + </dependency> + <dependency> + <groupId>org.scalatest</groupId> + <artifactId>scalatest_${scala.binary.version}</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.scalacheck</groupId> + <artifactId>scalacheck_${scala.binary.version}</artifactId> + <scope>test</scope> + </dependency> + </dependencies> + <build> + <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory> + <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory> + <plugins> + <plugin> + <groupId>org.scalatest</groupId> + <artifactId>scalatest-maven-plugin</artifactId> + </plugin> + </plugins> + </build> +</project> diff --git a/sql/hive/src/main/scala/org/apache/hadoop/mapred/SparkHadoopWriter.scala b/sql/hive/src/main/scala/org/apache/hadoop/mapred/SparkHadoopWriter.scala new file mode 100644 index 0000000000..08d390e887 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/hadoop/mapred/SparkHadoopWriter.scala @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.mapred + +import java.io.IOException +import java.text.NumberFormat +import java.util.Date + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.Writable + +import org.apache.spark.Logging +import org.apache.spark.SerializableWritable + +import org.apache.hadoop.hive.ql.exec.{Utilities, FileSinkOperator} +import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat} +import org.apache.hadoop.hive.ql.plan.FileSinkDesc + +/** + * Internal helper class that saves an RDD using a Hive OutputFormat. + * It is based on [[SparkHadoopWriter]]. + */ +protected[apache] +class SparkHiveHadoopWriter( + @transient jobConf: JobConf, + fileSinkConf: FileSinkDesc) + extends Logging + with SparkHadoopMapRedUtil + with Serializable { + + private val now = new Date() + private val conf = new SerializableWritable(jobConf) + + private var jobID = 0 + private var splitID = 0 + private var attemptID = 0 + private var jID: SerializableWritable[JobID] = null + private var taID: SerializableWritable[TaskAttemptID] = null + + @transient private var writer: FileSinkOperator.RecordWriter = null + @transient private var format: HiveOutputFormat[AnyRef, Writable] = null + @transient private var committer: OutputCommitter = null + @transient private var jobContext: JobContext = null + @transient private var taskContext: TaskAttemptContext = null + + def preSetup() { + setIDs(0, 0, 0) + setConfParams() + + val jCtxt = getJobContext() + getOutputCommitter().setupJob(jCtxt) + } + + + def setup(jobid: Int, splitid: Int, attemptid: Int) { + setIDs(jobid, splitid, attemptid) + setConfParams() + } + + def open() { + val numfmt = NumberFormat.getInstance() + numfmt.setMinimumIntegerDigits(5) + numfmt.setGroupingUsed(false) + + val extension = Utilities.getFileExtension( + conf.value, + fileSinkConf.getCompressed, + getOutputFormat()) + + val outputName = "part-" + numfmt.format(splitID) + extension + val path = FileOutputFormat.getTaskOutputPath(conf.value, outputName) + + getOutputCommitter().setupTask(getTaskContext()) + writer = HiveFileFormatUtils.getHiveRecordWriter( + conf.value, + fileSinkConf.getTableInfo, + conf.value.getOutputValueClass.asInstanceOf[Class[Writable]], + fileSinkConf, + path, + null) + } + + def write(value: Writable) { + if (writer != null) { + writer.write(value) + } else { + throw new IOException("Writer is null, open() has not been called") + } + } + + def close() { + // Seems the boolean value passed into close does not matter. 
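// Added note (not from the diff): taken together, the methods of this class
// imply a lifecycle like the following; the driver/task attribution is an
// assumption based on the SparkHadoopWriter this class mirrors:
//   writer.preSetup()                        // once: OutputCommitter.setupJob
//   writer.setup(jobId, splitId, attemptId)  // per task attempt
//   writer.open()                            // creates the Hive RecordWriter
//   values.foreach(writer.write)             // Writable values
//   writer.close(); writer.commit()          // task-side commit if needed
//   writer.commitJob()                       // once, after all tasks finish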
+ writer.close(false) + } + + def commit() { + val taCtxt = getTaskContext() + val cmtr = getOutputCommitter() + if (cmtr.needsTaskCommit(taCtxt)) { + try { + cmtr.commitTask(taCtxt) + logInfo (taID + ": Committed") + } catch { + case e: IOException => { + logError("Error committing the output of task: " + taID.value, e) + cmtr.abortTask(taCtxt) + throw e + } + } + } else { + logWarning ("No need to commit output of task: " + taID.value) + } + } + + def commitJob() { + // always ? Or if cmtr.needsTaskCommit ? + val cmtr = getOutputCommitter() + cmtr.commitJob(getJobContext()) + } + + // ********* Private Functions ********* + + private def getOutputFormat(): HiveOutputFormat[AnyRef,Writable] = { + if (format == null) { + format = conf.value.getOutputFormat() + .asInstanceOf[HiveOutputFormat[AnyRef,Writable]] + } + format + } + + private def getOutputCommitter(): OutputCommitter = { + if (committer == null) { + committer = conf.value.getOutputCommitter + } + committer + } + + private def getJobContext(): JobContext = { + if (jobContext == null) { + jobContext = newJobContext(conf.value, jID.value) + } + jobContext + } + + private def getTaskContext(): TaskAttemptContext = { + if (taskContext == null) { + taskContext = newTaskAttemptContext(conf.value, taID.value) + } + taskContext + } + + private def setIDs(jobid: Int, splitid: Int, attemptid: Int) { + jobID = jobid + splitID = splitid + attemptID = attemptid + + jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobid)) + taID = new SerializableWritable[TaskAttemptID]( + new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID)) + } + + private def setConfParams() { + conf.value.set("mapred.job.id", jID.value.toString) + conf.value.set("mapred.tip.id", taID.value.getTaskID.toString) + conf.value.set("mapred.task.id", taID.value.toString) + conf.value.setBoolean("mapred.task.is.map", true) + conf.value.setInt("mapred.task.partition", splitID) + } +} + +object SparkHiveHadoopWriter { + def createPathFromString(path: String, conf: JobConf): Path = { + if (path == null) { + throw new IllegalArgumentException("Output path is null") + } + val outputPath = new Path(path) + val fs = outputPath.getFileSystem(conf) + if (outputPath == null || fs == null) { + throw new IllegalArgumentException("Incorrectly formatted output path") + } + outputPath.makeQualified(fs) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala new file mode 100644 index 0000000000..4aad876cc0 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql +package hive + +import java.io.{PrintStream, InputStreamReader, BufferedReader, File} +import java.util.{ArrayList => JArrayList} +import scala.language.implicitConversions + +import org.apache.spark.SparkContext +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.ql.processors.{CommandProcessorResponse, CommandProcessorFactory} +import org.apache.hadoop.hive.ql.processors.CommandProcessor +import org.apache.hadoop.hive.ql.Driver +import org.apache.spark.rdd.RDD + +import catalyst.analysis.{Analyzer, OverrideCatalog} +import catalyst.expressions.GenericRow +import catalyst.plans.logical.{BaseRelation, LogicalPlan, NativeCommand, ExplainCommand} +import catalyst.types._ + +import org.apache.spark.sql.execution._ + +import scala.collection.JavaConversions._ + +/** + * Starts up an instance of Hive where metadata is stored locally. An in-process metastore is + * created with data stored in ./metastore. Warehouse data is stored in ./warehouse. + */ +class LocalHiveContext(sc: SparkContext) extends HiveContext(sc) { + + lazy val metastorePath = new File("metastore").getCanonicalPath + lazy val warehousePath: String = new File("warehouse").getCanonicalPath + + /** Sets up the system initially or after a RESET command */ + protected def configure() { + // TODO: refactor this so we can work with other databases. + runSqlHive( + s"set javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$metastorePath;create=true") + runSqlHive("set hive.metastore.warehouse.dir=" + warehousePath) + } + + configure() // Must be called before initializing the catalog below. +} + +/** + * An instance of the Spark SQL execution engine that integrates with data stored in Hive. + * Configuration for Hive is read from hive-site.xml on the classpath. + */ +class HiveContext(sc: SparkContext) extends SQLContext(sc) { + self => + + override def parseSql(sql: String): LogicalPlan = HiveQl.parseSql(sql) + override def executePlan(plan: LogicalPlan): this.QueryExecution = + new this.QueryExecution { val logical = plan } + + // Circular buffer to hold what Hive prints to STDOUT and ERR. Only printed when failures occur. + @transient + protected val outputBuffer = new java.io.OutputStream { + var pos: Int = 0 + var buffer = new Array[Int](10240) + def write(i: Int): Unit = { + buffer(pos) = i + pos = (pos + 1) % buffer.size + } + + override def toString = { + val (end, start) = buffer.splitAt(pos) + val input = new java.io.InputStream { + val iterator = (start ++ end).iterator + + def read(): Int = if (iterator.hasNext) iterator.next else -1 + } + val reader = new BufferedReader(new InputStreamReader(input)) + val stringBuilder = new StringBuilder + var line = reader.readLine() + while(line != null) { + stringBuilder.append(line) + stringBuilder.append("\n") + line = reader.readLine() + } + stringBuilder.toString() + } + } + + @transient protected[hive] lazy val hiveconf = new HiveConf(classOf[SessionState]) + @transient protected[hive] lazy val sessionState = new SessionState(hiveconf) + + sessionState.err = new PrintStream(outputBuffer, true, "UTF-8") + sessionState.out = new PrintStream(outputBuffer, true, "UTF-8") + + /* A catalyst metadata catalog that points to the Hive Metastore. */ + @transient + override lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog + + /* An analyzer that uses the Hive metastore.
*/ + @transient + override lazy val analyzer = new Analyzer(catalog, HiveFunctionRegistry, caseSensitive = false) + + def tables: Seq[BaseRelation] = { + // TODO: Move this functionality to Catalog. Make client protected. + val allTables = catalog.client.getAllTables("default") + allTables.map(catalog.lookupRelation(None, _, None)).collect { case b: BaseRelation => b } + } + + /** + * Runs the specified SQL query using Hive. + */ + protected def runSqlHive(sql: String): Seq[String] = { + val maxResults = 100000 + val results = runHive(sql, maxResults) + // It is very confusing when you only get back some of the results... + if (results.size == maxResults) sys.error("RESULTS POSSIBLY TRUNCATED") + results + } + + // TODO: Move this. + + SessionState.start(sessionState) + + /** + * Execute the command using Hive and return the results as a sequence. Each element + * in the sequence is one row. + */ + protected def runHive(cmd: String, maxRows: Int = 1000): Seq[String] = { + try { + val cmd_trimmed: String = cmd.trim() + val tokens: Array[String] = cmd_trimmed.split("\\s+") + val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() + val proc: CommandProcessor = CommandProcessorFactory.get(tokens(0), hiveconf) + + SessionState.start(sessionState) + + if (proc.isInstanceOf[Driver]) { + val driver: Driver = proc.asInstanceOf[Driver] + driver.init() + + val results = new JArrayList[String] + val response: CommandProcessorResponse = driver.run(cmd) + // Throw an exception if there is an error in query processing. + if (response.getResponseCode != 0) { + driver.destroy() + throw new QueryExecutionException(response.getErrorMessage) + } + driver.setMaxRows(maxRows) + driver.getResults(results) + driver.destroy() + results + } else { + sessionState.out.println(tokens(0) + " " + cmd_1) + Seq(proc.run(cmd_1).getResponseCode.toString) + } + } catch { + case e: Exception => + logger.error( + s""" + |====================== + |HIVE FAILURE OUTPUT + |====================== + |${outputBuffer.toString} + |====================== + |END HIVE FAILURE OUTPUT + |====================== + """.stripMargin) + throw e + } + } + + @transient + val hivePlanner = new SparkPlanner with HiveStrategies { + val hiveContext = self + + override val strategies: Seq[Strategy] = Seq( + TopK, + ColumnPrunings, + PartitionPrunings, + HiveTableScans, + DataSinks, + Scripts, + PartialAggregation, + SparkEquiInnerJoin, + BasicOperators, + CartesianProduct, + BroadcastNestedLoopJoin + ) + } + + @transient + override val planner = hivePlanner + + @transient + protected lazy val emptyResult = + sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1) + + /** Extends QueryExecution with Hive-specific features. */ + abstract class QueryExecution extends super.QueryExecution { + // TODO: Create mixin for the analyzer instead of overriding things here. + override lazy val optimizedPlan = + optimizer(catalog.PreInsertionCasts(catalog.CreateTables(analyzed))) + + // TODO: We are losing schema here.
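// Illustrative note (an addition, not from the diff): for a NativeCommand the
// rows below are rebuilt by splitting Hive's textual output on tabs, so a line
// like "238\tval_238" becomes GenericRow(Array("238", "val_238")). Every column
// comes back as a String, which is the schema loss the TODO above refers to.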
+ override lazy val toRdd: RDD[Row] = + analyzed match { + case NativeCommand(cmd) => + val output = runSqlHive(cmd) + + if (output.size == 0) { + emptyResult + } else { + val asRows = output.map(r => new GenericRow(r.split("\t").asInstanceOf[Array[Any]])) + sparkContext.parallelize(asRows, 1) + } + case _ => + executedPlan.execute.map(_.copy()) + } + + protected val primitiveTypes = + Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, + ShortType, DecimalType) + + protected def toHiveString(a: (Any, DataType)): String = a match { + case (struct: Row, StructType(fields)) => + struct.zip(fields).map { + case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" + }.mkString("{", ",", "}") + case (seq: Seq[_], ArrayType(typ))=> + seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") + case (map: Map[_,_], MapType(kType, vType)) => + map.map { + case (key, value) => + toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) + }.toSeq.sorted.mkString("{", ",", "}") + case (null, _) => "NULL" + case (other, tpe) if primitiveTypes contains tpe => other.toString + } + + /** Hive outputs fields of structs slightly differently than top level attributes. */ + protected def toHiveStructString(a: (Any, DataType)): String = a match { + case (struct: Row, StructType(fields)) => + struct.zip(fields).map { + case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" + }.mkString("{", ",", "}") + case (seq: Seq[_], ArrayType(typ))=> + seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") + case (map: Map[_,_], MapType(kType, vType)) => + map.map { + case (key, value) => + toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) + }.toSeq.sorted.mkString("{", ",", "}") + case (null, _) => "null" + case (s: String, StringType) => "\"" + s + "\"" + case (other, tpe) if primitiveTypes contains tpe => other.toString + } + + /** + * Returns the result as a hive compatible sequence of strings. For native commands, the + * execution is simply passed back to Hive. + */ + def stringResult(): Seq[String] = analyzed match { + case NativeCommand(cmd) => runSqlHive(cmd) + case ExplainCommand(plan) => new QueryExecution { val logical = plan }.toString.split("\n") + case query => + val result: Seq[Seq[Any]] = toRdd.collect().toSeq + // We need the types so we can output struct field names + val types = analyzed.output.map(_.dataType) + // Reformat to match hive tab delimited output. + val asString = result.map(_.zip(types).map(toHiveString)).map(_.mkString("\t")).toSeq + asString + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala new file mode 100644 index 0000000000..e4d50722ce --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package hive + +import scala.util.parsing.combinator.RegexParsers + +import org.apache.hadoop.hive.metastore.api.{FieldSchema, StorageDescriptor, SerDeInfo} +import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition} +import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} +import org.apache.hadoop.hive.ql.plan.TableDesc +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.serde2.Deserializer + +import catalyst.analysis.Catalog +import catalyst.expressions._ +import catalyst.plans.logical +import catalyst.plans.logical._ +import catalyst.rules._ +import catalyst.types._ + +import scala.collection.JavaConversions._ + +class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with Logging { + import HiveMetastoreTypes._ + + val client = Hive.get(hive.hiveconf) + + def lookupRelation( + db: Option[String], + tableName: String, + alias: Option[String]): LogicalPlan = { + val databaseName = db.getOrElse(hive.sessionState.getCurrentDatabase()) + val table = client.getTable(databaseName, tableName) + val partitions: Seq[Partition] = + if (table.isPartitioned) { + client.getPartitions(table) + } else { + Nil + } + + // Since HiveQL is case insensitive for table names we make them all lowercase. + MetastoreRelation( + databaseName.toLowerCase, + tableName.toLowerCase, + alias)(table.getTTable, partitions.map(part => part.getTPartition)) + } + + def createTable(databaseName: String, tableName: String, schema: Seq[Attribute]) { + val table = new Table(databaseName, tableName) + val hiveSchema = + schema.map(attr => new FieldSchema(attr.name, toMetastoreType(attr.dataType), "")) + table.setFields(hiveSchema) + + val sd = new StorageDescriptor() + table.getTTable.setSd(sd) + sd.setCols(hiveSchema) + + // TODO: THESE ARE ALL DEFAULTS, WE NEED TO PARSE / UNDERSTAND the output specs. + sd.setCompressed(false) + sd.setParameters(Map[String, String]()) + sd.setInputFormat("org.apache.hadoop.mapred.TextInputFormat") + sd.setOutputFormat("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat") + val serDeInfo = new SerDeInfo() + serDeInfo.setName(tableName) + serDeInfo.setSerializationLib("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") + serDeInfo.setParameters(Map[String, String]()) + sd.setSerdeInfo(serDeInfo) + client.createTable(table) + } + + /** + * Creates any tables required for query execution. + * For example, because of a CREATE TABLE X AS statement. + */ + object CreateTables extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case InsertIntoCreatedTable(db, tableName, child) => + val databaseName = db.getOrElse(SessionState.get.getCurrentDatabase()) + + createTable(databaseName, tableName, child.output) + + InsertIntoTable( + lookupRelation(Some(databaseName), tableName, None).asInstanceOf[BaseRelation], + Map.empty, + child, + overwrite = false) + } + } + + /** + * Casts input data to correct data types according to table definition before inserting into + * that table. 
*/ + object PreInsertionCasts extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan.transform { + // Wait until children are resolved + case p: LogicalPlan if !p.childrenResolved => p + + case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) => + val childOutputDataTypes = child.output.map(_.dataType) + // Only check attributes, not partitionKeys since they are always strings. + // TODO: Fully support inserting into partitioned tables. + val tableOutputDataTypes = table.attributes.map(_.dataType) + + if (childOutputDataTypes == tableOutputDataTypes) { + p + } else { + // Only do the casting when child output data types differ from table output data types. + val castedChildOutput = child.output.zip(table.output).map { + case (input, table) if input.dataType != table.dataType => + Alias(Cast(input, table.dataType), input.name)() + case (input, _) => input + } + + p.copy(child = logical.Project(castedChildOutput, child)) + } + } + } + + /** + * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore. + * For now, if this functionality is desired mix in the in-memory [[OverrideCatalog]]. + */ + override def registerTable( + databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = ??? +} + +object HiveMetastoreTypes extends RegexParsers { + protected lazy val primitiveType: Parser[DataType] = + "string" ^^^ StringType | + "float" ^^^ FloatType | + "int" ^^^ IntegerType | + "tinyint" ^^^ ShortType | + "double" ^^^ DoubleType | + "bigint" ^^^ LongType | + "binary" ^^^ BinaryType | + "boolean" ^^^ BooleanType | + "decimal" ^^^ DecimalType | + "varchar\\((\\d+)\\)".r ^^^ StringType + + protected lazy val arrayType: Parser[DataType] = + "array" ~> "<" ~> dataType <~ ">" ^^ ArrayType + + protected lazy val mapType: Parser[DataType] = + "map" ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ { + case t1 ~ _ ~ t2 => MapType(t1, t2) + } + + protected lazy val structField: Parser[StructField] = + "[a-zA-Z0-9]*".r ~ ":" ~ dataType ^^ { + case name ~ _ ~ tpe => StructField(name, tpe, nullable = true) + } + + protected lazy val structType: Parser[DataType] = + "struct" ~> "<" ~> repsep(structField,",") <~ ">" ^^ StructType + + protected lazy val dataType: Parser[DataType] = + arrayType | + mapType | + structType | + primitiveType + + def toDataType(metastoreType: String): DataType = parseAll(dataType, metastoreType) match { + case Success(result, _) => result + case failure: NoSuccess => sys.error(s"Unsupported dataType: $metastoreType") + } + + def toMetastoreType(dt: DataType): String = dt match { + case ArrayType(elementType) => s"array<${toMetastoreType(elementType)}>" + case StructType(fields) => + s"struct<${fields.map(f => s"${f.name}:${toMetastoreType(f.dataType)}").mkString(",")}>" + case MapType(keyType, valueType) => + s"map<${toMetastoreType(keyType)},${toMetastoreType(valueType)}>" + case StringType => "string" + case FloatType => "float" + case IntegerType => "int" + case ShortType => "tinyint" + case DoubleType => "double" + case LongType => "bigint" + case BinaryType => "binary" + case BooleanType => "boolean" + case DecimalType => "decimal" + } +} + +case class MetastoreRelation(databaseName: String, tableName: String, alias: Option[String]) + (val table: TTable, val partitions: Seq[TPartition]) + extends BaseRelation { + // TODO: Can we use org.apache.hadoop.hive.ql.metadata.Table as the type of table and + // use org.apache.hadoop.hive.ql.metadata.Partition as the type of elements of partitions.
+ // Right now, using org.apache.hadoop.hive.ql.metadata.Table and + // org.apache.hadoop.hive.ql.metadata.Partition will cause a NotSerializableException + // which indicates the SerDe we used is not Serializable. + + def hiveQlTable = new Table(table) + + def hiveQlPartitions = partitions.map { p => + new Partition(hiveQlTable, p) + } + + override def isPartitioned = hiveQlTable.isPartitioned + + val tableDesc = new TableDesc( + Class.forName(hiveQlTable.getSerializationLib).asInstanceOf[Class[Deserializer]], + hiveQlTable.getInputFormatClass, + // The class of table should be org.apache.hadoop.hive.ql.metadata.Table because + // getOutputFormatClass will use HiveFileFormatUtils.getOutputFormatSubstitute to + // substitute some output formats, e.g. substituting SequenceFileOutputFormat to + // HiveSequenceFileOutputFormat. + hiveQlTable.getOutputFormatClass, + hiveQlTable.getMetadata + ) + + implicit class SchemaAttribute(f: FieldSchema) { + def toAttribute = AttributeReference( + f.getName, + HiveMetastoreTypes.toDataType(f.getType), + // Since data can be dumped in randomly with no validation, everything is nullable. + nullable = true + )(qualifiers = tableName +: alias.toSeq) + } + + // Must be a stable value since new attributes are born here. + val partitionKeys = hiveQlTable.getPartitionKeys.map(_.toAttribute) + + /** Non-partitionKey attributes */ + val attributes = table.getSd.getCols.map(_.toAttribute) + + val output = attributes ++ partitionKeys +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala new file mode 100644 index 0000000000..4f33a293c3 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -0,0 +1,966 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package hive + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.hive.ql.lib.Node +import org.apache.hadoop.hive.ql.parse._ +import org.apache.hadoop.hive.ql.plan.PlanUtils + +import catalyst.analysis._ +import catalyst.expressions._ +import catalyst.plans._ +import catalyst.plans.logical +import catalyst.plans.logical._ +import catalyst.types._ + +/** + * Used when we need to start parsing the AST before deciding that we are going to pass the command + * back for Hive to execute natively. Will be replaced with a native command that contains the + * cmd string. 
*/ +case object NativePlaceholder extends Command + +case class DfsCommand(cmd: String) extends Command + +case class ShellCommand(cmd: String) extends Command + +case class SourceCommand(filePath: String) extends Command + +case class AddJar(jarPath: String) extends Command + +case class AddFile(filePath: String) extends Command + +/** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */ +object HiveQl { + protected val nativeCommands = Seq( + "TOK_DESCFUNCTION", + "TOK_DESCTABLE", + "TOK_DESCDATABASE", + "TOK_SHOW_TABLESTATUS", + "TOK_SHOWDATABASES", + "TOK_SHOWFUNCTIONS", + "TOK_SHOWINDEXES", + "TOK_SHOWPARTITIONS", + "TOK_SHOWTABLES", + + "TOK_LOCKTABLE", + "TOK_SHOWLOCKS", + "TOK_UNLOCKTABLE", + + "TOK_CREATEROLE", + "TOK_DROPROLE", + "TOK_GRANT", + "TOK_GRANT_ROLE", + "TOK_REVOKE", + "TOK_SHOW_GRANT", + "TOK_SHOW_ROLE_GRANT", + + "TOK_CREATEFUNCTION", + "TOK_DROPFUNCTION", + + "TOK_ANALYZE", + "TOK_ALTERDATABASE_PROPERTIES", + "TOK_ALTERINDEX_PROPERTIES", + "TOK_ALTERINDEX_REBUILD", + "TOK_ALTERTABLE_ADDCOLS", + "TOK_ALTERTABLE_ADDPARTS", + "TOK_ALTERTABLE_ALTERPARTS", + "TOK_ALTERTABLE_ARCHIVE", + "TOK_ALTERTABLE_CLUSTER_SORT", + "TOK_ALTERTABLE_DROPPARTS", + "TOK_ALTERTABLE_PARTITION", + "TOK_ALTERTABLE_PROPERTIES", + "TOK_ALTERTABLE_RENAME", + "TOK_ALTERTABLE_RENAMECOL", + "TOK_ALTERTABLE_REPLACECOLS", + "TOK_ALTERTABLE_SKEWED", + "TOK_ALTERTABLE_TOUCH", + "TOK_ALTERTABLE_UNARCHIVE", + "TOK_CREATEDATABASE", + "TOK_CREATEINDEX", + "TOK_DROPDATABASE", + "TOK_DROPINDEX", + "TOK_DROPTABLE", + "TOK_MSCK", + + // TODO(marmbrus): Figure out how views are expanded by Hive, as we might need to handle this. + "TOK_ALTERVIEW_ADDPARTS", + "TOK_ALTERVIEW_AS", + "TOK_ALTERVIEW_DROPPARTS", + "TOK_ALTERVIEW_PROPERTIES", + "TOK_ALTERVIEW_RENAME", + "TOK_CREATEVIEW", + "TOK_DROPVIEW", + + "TOK_EXPORT", + "TOK_IMPORT", + "TOK_LOAD", + + "TOK_SWITCHDATABASE" + ) + + /** + * A set of implicit transformations that allow Hive ASTNodes to be rewritten by transformations + * similar to [[catalyst.trees.TreeNode]]. + * + * Note that this should be considered very experimental and is not intended as a replacement + * for TreeNode. Primarily it should be noted that ASTNodes are not immutable and do not appear + * to have clean copy semantics. Therefore, users of this class should take care when + * copying/modifying trees that might be used elsewhere. + */ + implicit class TransformableNode(n: ASTNode) { + /** + * Returns a copy of this node where `rule` has been recursively applied to it and all of its + * children. When `rule` does not apply to a given node it is left unchanged. + * @param rule the function used to transform this node's children + */ + def transform(rule: PartialFunction[ASTNode, ASTNode]): ASTNode = { + try { + val afterRule = rule.applyOrElse(n, identity[ASTNode]) + afterRule.withChildren( + nilIfEmpty(afterRule.getChildren) + .asInstanceOf[Seq[ASTNode]] + .map(ast => Option(ast).map(_.transform(rule)).orNull)) + } catch { + case e: Exception => + println(dumpTree(n)) + throw e + } + } + + /** + * Returns a scala.Seq equivalent to [s] or Nil if [s] is null. + */ + private def nilIfEmpty[A](s: java.util.List[A]): Seq[A] = + Option(s).map(_.toSeq).getOrElse(Nil) + + /** + * Returns this ASTNode with the text changed to `newText`.
+ */ + def withText(newText: String): ASTNode = { + n.token.asInstanceOf[org.antlr.runtime.CommonToken].setText(newText) + n + } + + /** + * Returns this ASTNode with the children changed to `newChildren`. + */ + def withChildren(newChildren: Seq[ASTNode]): ASTNode = { + (1 to n.getChildCount).foreach(_ => n.deleteChild(0)) + n.addChildren(newChildren) + n + } + + /** + * Throws an error if this is not equal to other. + * + * Right now this function only checks the name, type, text and children of the node + * for equality. + */ + def checkEquals(other: ASTNode) { + def check(field: String, f: ASTNode => Any) = if (f(n) != f(other)) { + sys.error(s"$field does not match for trees. " + + s"'${f(n)}' != '${f(other)}' left: ${dumpTree(n)}, right: ${dumpTree(other)}") + } + check("name", _.getName) + check("type", _.getType) + check("text", _.getText) + check("numChildren", n => nilIfEmpty(n.getChildren).size) + + val leftChildren = nilIfEmpty(n.getChildren).asInstanceOf[Seq[ASTNode]] + val rightChildren = nilIfEmpty(other.getChildren).asInstanceOf[Seq[ASTNode]] + leftChildren zip rightChildren foreach { + case (l,r) => l checkEquals r + } + } + } + + class ParseException(sql: String, cause: Throwable) + extends Exception(s"Failed to parse: $sql", cause) + + /** + * Returns the AST for the given SQL string. + */ + def getAst(sql: String): ASTNode = ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql)) + + /** Returns a LogicalPlan for a given HiveQL string. */ + def parseSql(sql: String): LogicalPlan = { + try { + if (sql.toLowerCase.startsWith("set")) { + NativeCommand(sql) + } else if (sql.toLowerCase.startsWith("add jar")) { + AddJar(sql.drop(8)) + } else if (sql.toLowerCase.startsWith("add file")) { + AddFile(sql.drop(9)) + } else if (sql.startsWith("dfs")) { + DfsCommand(sql) + } else if (sql.startsWith("source")) { + SourceCommand(sql.split(" ").toSeq match { case Seq("source", filePath) => filePath }) + } else if (sql.startsWith("!")) { + ShellCommand(sql.drop(1)) + } else { + val tree = getAst(sql) + + if (nativeCommands contains tree.getText) { + NativeCommand(sql) + } else { + nodeToPlan(tree) match { + case NativePlaceholder => NativeCommand(sql) + case other => other + } + } + } + } catch { + case e: Exception => throw new ParseException(sql, e) + } + } + + def parseDdl(ddl: String): Seq[Attribute] = { + val tree = + try { + ParseUtils.findRootNonNullToken( + (new ParseDriver).parse(ddl, null /* no context required for parsing alone */)) + } catch { + case pe: org.apache.hadoop.hive.ql.parse.ParseException => + throw new RuntimeException(s"Failed to parse ddl: '$ddl'", pe) + } + assert(tree.asInstanceOf[ASTNode].getText == "TOK_CREATETABLE", "Only CREATE TABLE supported.") + val tableOps = tree.getChildren + val colList = + tableOps + .find(_.asInstanceOf[ASTNode].getText == "TOK_TABCOLLIST") + .getOrElse(sys.error("No columnList!")).getChildren + + colList.map(nodeToAttribute) + } + + /** Extractor for matching Hive's AST Tokens. */ + object Token { + /** @return matches of the form (tokenName, children). 
*/ + def unapply(t: Any): Option[(String, Seq[ASTNode])] = t match { + case t: ASTNode => + Some((t.getText, + Option(t.getChildren).map(_.toList).getOrElse(Nil).asInstanceOf[Seq[ASTNode]])) + case _ => None + } + } + + protected def getClauses(clauseNames: Seq[String], nodeList: Seq[ASTNode]): Seq[Option[Node]] = { + var remainingNodes = nodeList + val clauses = clauseNames.map { clauseName => + val (matches, nonMatches) = remainingNodes.partition(_.getText.toUpperCase == clauseName) + remainingNodes = nonMatches ++ (if (matches.nonEmpty) matches.tail else Nil) + matches.headOption + } + + assert(remainingNodes.isEmpty, + s"Unhandled clauses: ${remainingNodes.map(dumpTree(_)).mkString("\n")}") + clauses + } + + def getClause(clauseName: String, nodeList: Seq[Node]) = + getClauseOption(clauseName, nodeList).getOrElse(sys.error( + s"Expected clause $clauseName missing from ${nodeList.map(dumpTree(_)).mkString("\n")}")) + + def getClauseOption(clauseName: String, nodeList: Seq[Node]): Option[Node] = { + nodeList.filter { case ast: ASTNode => ast.getText == clauseName } match { + case Seq(oneMatch) => Some(oneMatch) + case Seq() => None + case _ => sys.error(s"Found multiple instances of clause $clauseName") + } + } + + protected def nodeToAttribute(node: Node): Attribute = node match { + case Token("TOK_TABCOL", Token(colName, Nil) :: dataType :: Nil) => + AttributeReference(colName, nodeToDataType(dataType), true)() + + case a: ASTNode => + throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") + } + + protected def nodeToDataType(node: Node): DataType = node match { + case Token("TOK_BIGINT", Nil) => IntegerType + case Token("TOK_INT", Nil) => IntegerType + case Token("TOK_TINYINT", Nil) => IntegerType + case Token("TOK_SMALLINT", Nil) => IntegerType + case Token("TOK_BOOLEAN", Nil) => BooleanType + case Token("TOK_STRING", Nil) => StringType + case Token("TOK_FLOAT", Nil) => FloatType + case Token("TOK_DOUBLE", Nil) => FloatType + case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType)) + case Token("TOK_STRUCT", + Token("TOK_TABCOLLIST", fields) :: Nil) => + StructType(fields.map(nodeToStructField)) + case Token("TOK_MAP", + keyType :: + valueType :: Nil) => + MapType(nodeToDataType(keyType), nodeToDataType(valueType)) + case a: ASTNode => + throw new NotImplementedError(s"No parse rules for DataType:\n ${dumpTree(a).toString} ") + } + + protected def nodeToStructField(node: Node): StructField = node match { + case Token("TOK_TABCOL", + Token(fieldName, Nil) :: + dataType :: Nil) => + StructField(fieldName, nodeToDataType(dataType), nullable = true) + case Token("TOK_TABCOL", + Token(fieldName, Nil) :: + dataType :: + _ /* comment */:: Nil) => + StructField(fieldName, nodeToDataType(dataType), nullable = true) + case a: ASTNode => + throw new NotImplementedError(s"No parse rules for StructField:\n ${dumpTree(a).toString} ") + } + + protected def nameExpressions(exprs: Seq[Expression]): Seq[NamedExpression] = { + exprs.zipWithIndex.map { + case (ne: NamedExpression, _) => ne + case (e, i) => Alias(e, s"c_$i")() + } + } + + protected def nodeToPlan(node: Node): LogicalPlan = node match { + // Just fake explain for any of the native commands. + case Token("TOK_EXPLAIN", explainArgs) if nativeCommands contains explainArgs.head.getText => + NoRelation + case Token("TOK_EXPLAIN", explainArgs) => + // Ignore FORMATTED if present. 
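// Illustrative note (an addition, not from the diff): getClauses, defined
// above, returns one Option per requested clause name in the order given, so a
// positional destructuring like the following line extracts TOK_QUERY while
// tolerating the optional FORMATTED / EXTENDED children.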
+ val Some(query) :: _ :: _ :: Nil = + getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs) + // TODO: support EXTENDED? + ExplainCommand(nodeToPlan(query)) + + case Token("TOK_CREATETABLE", children) + if children.collect { case t@Token("TOK_QUERY", _) => t }.nonEmpty => + // TODO: Parse other clauses. + // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL + val ( + Some(tableNameParts) :: + _ /* likeTable */ :: + Some(query) +: + notImplemented) = + getClauses( + Seq( + "TOK_TABNAME", + "TOK_LIKETABLE", + "TOK_QUERY", + "TOK_IFNOTEXISTS", + "TOK_TABLECOMMENT", + "TOK_TABCOLLIST", + "TOK_TABLEPARTCOLS", // Partitioned by + "TOK_TABLEBUCKETS", // Clustered by + "TOK_TABLESKEWED", // Skewed by + "TOK_TABLEROWFORMAT", + "TOK_TABLESERIALIZER", + "TOK_FILEFORMAT_GENERIC", // For file formats not natively supported by Hive. + "TOK_TBLSEQUENCEFILE", // Stored as SequenceFile + "TOK_TBLTEXTFILE", // Stored as TextFile + "TOK_TBLRCFILE", // Stored as RCFile + "TOK_TBLORCFILE", // Stored as ORC File + "TOK_TABLEFILEFORMAT", // User-provided InputFormat and OutputFormat + "TOK_STORAGEHANDLER", // Storage handler + "TOK_TABLELOCATION", + "TOK_TABLEPROPERTIES"), + children) + if (notImplemented.exists(token => !token.isEmpty)) { + throw new NotImplementedError( + s"Unhandled clauses: ${notImplemented.flatten.map(dumpTree(_)).mkString("\n")}") + } + + val (db, tableName) = + tableNameParts.getChildren.map{ case Token(part, Nil) => cleanIdentifier(part)} match { + case Seq(tableOnly) => (None, tableOnly) + case Seq(databaseName, table) => (Some(databaseName), table) + } + InsertIntoCreatedTable(db, tableName, nodeToPlan(query)) + + // If its not a "CREATE TABLE AS" like above then just pass it back to hive as a native command. + case Token("TOK_CREATETABLE", _) => NativePlaceholder + + case Token("TOK_QUERY", + Token("TOK_FROM", fromClause :: Nil) :: + insertClauses) => + + // Return one query for each insert clause. 
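// Illustrative note (an addition, not from the diff): HiveQL allows
// multi-insert statements such as
//   FROM src
//   INSERT OVERWRITE TABLE a SELECT key WHERE key < 100
//   INSERT OVERWRITE TABLE b SELECT value WHERE key >= 100
// so each TOK_INSERT child is planned independently and the resulting plans
// are unioned at the end of this case.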
+ val queries = insertClauses.map { case Token("TOK_INSERT", singleInsert) => + val ( + intoClause :: + destClause :: + selectClause :: + selectDistinctClause :: + whereClause :: + groupByClause :: + orderByClause :: + sortByClause :: + clusterByClause :: + distributeByClause :: + limitClause :: + lateralViewClause :: Nil) = { + getClauses( + Seq( + "TOK_INSERT_INTO", + "TOK_DESTINATION", + "TOK_SELECT", + "TOK_SELECTDI", + "TOK_WHERE", + "TOK_GROUPBY", + "TOK_ORDERBY", + "TOK_SORTBY", + "TOK_CLUSTERBY", + "TOK_DISTRIBUTEBY", + "TOK_LIMIT", + "TOK_LATERAL_VIEW"), + singleInsert) + } + + val relations = nodeToRelation(fromClause) + val withWhere = whereClause.map { whereNode => + val Seq(whereExpr) = whereNode.getChildren.toSeq + Filter(nodeToExpr(whereExpr), relations) + }.getOrElse(relations) + + val select = + (selectClause orElse selectDistinctClause).getOrElse(sys.error("No select clause.")) + + // Script transformations are expressed as a select clause with a single expression of type + // TOK_TRANSFORM + val transformation = select.getChildren.head match { + case Token("TOK_SELEXPR", + Token("TOK_TRANSFORM", + Token("TOK_EXPLIST", inputExprs) :: + Token("TOK_SERDE", Nil) :: + Token("TOK_RECORDWRITER", writerClause) :: + // TODO: Need to support other types of (in/out)put + Token(script, Nil) :: + Token("TOK_SERDE", serdeClause) :: + Token("TOK_RECORDREADER", readerClause) :: + outputClause :: Nil) :: Nil) => + + val output = outputClause match { + case Token("TOK_ALIASLIST", aliases) => + aliases.map { case Token(name, Nil) => AttributeReference(name, StringType)() } + case Token("TOK_TABCOLLIST", attributes) => + attributes.map { case Token("TOK_TABCOL", Token(name, Nil) :: dataType :: Nil) => + AttributeReference(name, nodeToDataType(dataType))() } + } + val unescapedScript = BaseSemanticAnalyzer.unescapeSQLString(script) + + Some( + logical.ScriptTransformation( + inputExprs.map(nodeToExpr), + unescapedScript, + output, + withWhere)) + case _ => None + } + + val withLateralView = lateralViewClause.map { lv => + val Token("TOK_SELECT", + Token("TOK_SELEXPR", clauses) :: Nil) = lv.getChildren.head + + val alias = + getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText + + Generate( + nodesToGenerator(clauses), + join = true, + outer = false, + Some(alias.toLowerCase), + withWhere) + }.getOrElse(withWhere) + + + // The projection of the query can either be a normal projection, an aggregation + // (if there is a group by) or a script transformation. + val withProject = transformation.getOrElse { + // Not a transformation so must be either project or aggregation. 
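// Illustrative note (an addition, not from the diff): the TOK_TRANSFORM case
// above corresponds to HiveQL script transforms, e.g.
//   SELECT TRANSFORM (key, value) USING '/bin/cat' AS (k, v) FROM src
// which become logical.ScriptTransformation nodes; anything else falls through
// to the Project / Aggregate handling below.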
+ val selectExpressions = nameExpressions(select.getChildren.flatMap(selExprNodeToExpr)) + + groupByClause match { + case Some(groupBy) => + Aggregate(groupBy.getChildren.map(nodeToExpr), selectExpressions, withLateralView) + case None => + Project(selectExpressions, withLateralView) + } + } + + val withDistinct = + if (selectDistinctClause.isDefined) Distinct(withProject) else withProject + + val withSort = + (orderByClause, sortByClause, distributeByClause, clusterByClause) match { + case (Some(totalOrdering), None, None, None) => + Sort(totalOrdering.getChildren.map(nodeToSortOrder), withDistinct) + case (None, Some(perPartitionOrdering), None, None) => + SortPartitions(perPartitionOrdering.getChildren.map(nodeToSortOrder), withDistinct) + case (None, None, Some(partitionExprs), None) => + Repartition(partitionExprs.getChildren.map(nodeToExpr), withDistinct) + case (None, Some(perPartitionOrdering), Some(partitionExprs), None) => + SortPartitions(perPartitionOrdering.getChildren.map(nodeToSortOrder), + Repartition(partitionExprs.getChildren.map(nodeToExpr), withDistinct)) + case (None, None, None, Some(clusterExprs)) => + SortPartitions(clusterExprs.getChildren.map(nodeToExpr).map(SortOrder(_, Ascending)), + Repartition(clusterExprs.getChildren.map(nodeToExpr), withDistinct)) + case (None, None, None, None) => withDistinct + case _ => sys.error("Unsupported set of ordering / distribution clauses.") + } + + val withLimit = + limitClause.map(l => nodeToExpr(l.getChildren.head)) + .map(StopAfter(_, withSort)) + .getOrElse(withSort) + + // TOK_INSERT_INTO means to add files to the table. + // TOK_DESTINATION means to overwrite the table. + val resultDestination = + (intoClause orElse destClause).getOrElse(sys.error("No destination found.")) + val overwrite = if (intoClause.isEmpty) true else false + nodeToDest( + resultDestination, + withLimit, + overwrite) + } + + // If there are multiple INSERTS just UNION them together into on query. + queries.reduceLeft(Union) + + case Token("TOK_UNION", left :: right :: Nil) => Union(nodeToPlan(left), nodeToPlan(right)) + + case a: ASTNode => + throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") + } + + val allJoinTokens = "(TOK_.*JOIN)".r + val laterViewToken = "TOK_LATERAL_VIEW(.*)".r + def nodeToRelation(node: Node): LogicalPlan = node match { + case Token("TOK_SUBQUERY", + query :: Token(alias, Nil) :: Nil) => + Subquery(alias, nodeToPlan(query)) + + case Token(laterViewToken(isOuter), selectClause :: relationClause :: Nil) => + val Token("TOK_SELECT", + Token("TOK_SELEXPR", clauses) :: Nil) = selectClause + + val alias = getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText + + Generate( + nodesToGenerator(clauses), + join = true, + outer = isOuter.nonEmpty, + Some(alias.toLowerCase), + nodeToRelation(relationClause)) + + /* All relations, possibly with aliases or sampling clauses. */ + case Token("TOK_TABREF", clauses) => + // If the last clause is not a token then it's the alias of the table. 
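// Illustrative note (an addition, not from the diff): the withSort match above
// maps Hive's ordering clauses onto distinct plan shapes:
//   ORDER BY x      -> Sort            (total ordering)
//   SORT BY x       -> SortPartitions  (ordering within each partition)
//   DISTRIBUTE BY x -> Repartition
//   CLUSTER BY x    -> Repartition followed by SortPartitions on x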
+ val (nonAliasClauses, aliasClause) = + if (clauses.last.getText.startsWith("TOK")) { + (clauses, None) + } else { + (clauses.dropRight(1), Some(clauses.last)) + } + + val (Some(tableNameParts) :: + splitSampleClause :: + bucketSampleClause :: Nil) = { + getClauses(Seq("TOK_TABNAME", "TOK_TABLESPLITSAMPLE", "TOK_TABLEBUCKETSAMPLE"), + nonAliasClauses) + } + + val (db, tableName) = + tableNameParts.getChildren.map{ case Token(part, Nil) => cleanIdentifier(part)} match { + case Seq(tableOnly) => (None, tableOnly) + case Seq(databaseName, table) => (Some(databaseName), table) + } + val alias = aliasClause.map { case Token(a, Nil) => cleanIdentifier(a) } + val relation = UnresolvedRelation(db, tableName, alias) + + // Apply sampling if requested. + (bucketSampleClause orElse splitSampleClause).map { + case Token("TOK_TABLESPLITSAMPLE", + Token("TOK_ROWCOUNT", Nil) :: + Token(count, Nil) :: Nil) => + StopAfter(Literal(count.toInt), relation) + case Token("TOK_TABLESPLITSAMPLE", + Token("TOK_PERCENT", Nil) :: + Token(fraction, Nil) :: Nil) => + Sample(fraction.toDouble, withReplacement = false, (math.random * 1000).toInt, relation) + case Token("TOK_TABLEBUCKETSAMPLE", + Token(numerator, Nil) :: + Token(denominator, Nil) :: Nil) => + val fraction = numerator.toDouble / denominator.toDouble + Sample(fraction, withReplacement = false, (math.random * 1000).toInt, relation) + case a: ASTNode => + throw new NotImplementedError( + s"""No parse rules for sampling clause: ${a.getType}, text: ${a.getText} : + |${dumpTree(a).toString}" + + """.stripMargin) + }.getOrElse(relation) + + case Token("TOK_UNIQUEJOIN", joinArgs) => + val tableOrdinals = + joinArgs.zipWithIndex.filter { + case (arg, i) => arg.getText == "TOK_TABREF" + }.map(_._2) + + val isPreserved = tableOrdinals.map(i => (i - 1 < 0) || joinArgs(i - 1).getText == "PRESERVE") + val tables = tableOrdinals.map(i => nodeToRelation(joinArgs(i))) + val joinExpressions = tableOrdinals.map(i => joinArgs(i + 1).getChildren.map(nodeToExpr)) + + val joinConditions = joinExpressions.sliding(2).map { + case Seq(c1, c2) => + val predicates = (c1, c2).zipped.map { case (e1, e2) => Equals(e1, e2): Expression } + predicates.reduceLeft(And) + }.toBuffer + + val joinType = isPreserved.sliding(2).map { + case Seq(true, true) => FullOuter + case Seq(true, false) => LeftOuter + case Seq(false, true) => RightOuter + case Seq(false, false) => Inner + }.toBuffer + + val joinedTables = tables.reduceLeft(Join(_,_, Inner, None)) + + // Must be transform down. + val joinedResult = joinedTables transform { + case j: Join => + j.copy( + condition = Some(joinConditions.remove(joinConditions.length - 1)), + joinType = joinType.remove(joinType.length - 1)) + } + + val groups = (0 until joinExpressions.head.size).map(i => Coalesce(joinExpressions.map(_(i)))) + + // Unique join is not really the same as an outer join so we must group together results where + // the joinExpressions are the same, taking the First of each value is only okay because the + // user of a unique join is implicitly promising that there is only one result. + // TODO: This doesn't actually work since [[Star]] is not a valid aggregate expression. + // instead we should figure out how important supporting this feature is and whether it is + // worth the number of hacks that will be required to implement it. Namely, we need to add + // some sort of mapped star expansion that would expand all child output row to be similarly + // named output expressions where some aggregate expression has been applied (i.e. 
First). + ??? /// Aggregate(groups, Star(None, First(_)) :: Nil, joinedResult) + + case Token(allJoinTokens(joinToken), + relation1 :: + relation2 :: other) => + assert(other.size <= 1, s"Unhandled join child ${other}") + val joinType = joinToken match { + case "TOK_JOIN" => Inner + case "TOK_RIGHTOUTERJOIN" => RightOuter + case "TOK_LEFTOUTERJOIN" => LeftOuter + case "TOK_FULLOUTERJOIN" => FullOuter + } + assert(other.size <= 1, "Unhandled join clauses.") + Join(nodeToRelation(relation1), + nodeToRelation(relation2), + joinType, + other.headOption.map(nodeToExpr)) + + case a: ASTNode => + throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") + } + + def nodeToSortOrder(node: Node): SortOrder = node match { + case Token("TOK_TABSORTCOLNAMEASC", sortExpr :: Nil) => + SortOrder(nodeToExpr(sortExpr), Ascending) + case Token("TOK_TABSORTCOLNAMEDESC", sortExpr :: Nil) => + SortOrder(nodeToExpr(sortExpr), Descending) + + case a: ASTNode => + throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") + } + + val destinationToken = "TOK_DESTINATION|TOK_INSERT_INTO".r + protected def nodeToDest( + node: Node, + query: LogicalPlan, + overwrite: Boolean): LogicalPlan = node match { + case Token(destinationToken(), + Token("TOK_DIR", + Token("TOK_TMP_FILE", Nil) :: Nil) :: Nil) => + query + + case Token(destinationToken(), + Token("TOK_TAB", + tableArgs) :: Nil) => + val Some(tableNameParts) :: partitionClause :: Nil = + getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs) + + val (db, tableName) = + tableNameParts.getChildren.map{ case Token(part, Nil) => cleanIdentifier(part)} match { + case Seq(tableOnly) => (None, tableOnly) + case Seq(databaseName, table) => (Some(databaseName), table) + } + + val partitionKeys = partitionClause.map(_.getChildren.map { + // Parse partitions. We also make keys case insensitive. 
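// Illustrative note (an addition, not from the diff; the table and value are
// hypothetical): INSERT OVERWRITE TABLE t PARTITION (ds='1') SELECT ... yields
// partitionKeys Map("ds" -> Some("1")), while a bare PARTITION (ds) yields
// Map("ds" -> None), which the dynamic-partitioning check below rejects.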
+ case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => + cleanIdentifier(key.toLowerCase) -> Some(PlanUtils.stripQuotes(value)) + case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) => + cleanIdentifier(key.toLowerCase) -> None + }.toMap).getOrElse(Map.empty) + + if (partitionKeys.values.exists(p => p.isEmpty)) { + throw new NotImplementedError(s"Do not support INSERT INTO/OVERWRITE with" + + s"dynamic partitioning.") + } + + InsertIntoTable(UnresolvedRelation(db, tableName, None), partitionKeys, query, overwrite) + + case a: ASTNode => + throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") + } + + protected def selExprNodeToExpr(node: Node): Option[Expression] = node match { + case Token("TOK_SELEXPR", + e :: Nil) => + Some(nodeToExpr(e)) + + case Token("TOK_SELEXPR", + e :: Token(alias, Nil) :: Nil) => + Some(Alias(nodeToExpr(e), alias)()) + + /* Hints are ignored */ + case Token("TOK_HINTLIST", _) => None + + case a: ASTNode => + throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") + } + + + protected val escapedIdentifier = "`([^`]+)`".r + /** Strips backticks from ident if present */ + protected def cleanIdentifier(ident: String): String = ident match { + case escapedIdentifier(i) => i + case plainIdent => plainIdent + } + + val numericAstTypes = Seq( + HiveParser.Number, + HiveParser.TinyintLiteral, + HiveParser.SmallintLiteral, + HiveParser.BigintLiteral) + + /* Case insensitive matches */ + val COUNT = "(?i)COUNT".r + val AVG = "(?i)AVG".r + val SUM = "(?i)SUM".r + val RAND = "(?i)RAND".r + val AND = "(?i)AND".r + val OR = "(?i)OR".r + val NOT = "(?i)NOT".r + val TRUE = "(?i)TRUE".r + val FALSE = "(?i)FALSE".r + + protected def nodeToExpr(node: Node): Expression = node match { + /* Attribute References */ + case Token("TOK_TABLE_OR_COL", + Token(name, Nil) :: Nil) => + UnresolvedAttribute(cleanIdentifier(name)) + case Token(".", qualifier :: Token(attr, Nil) :: Nil) => + nodeToExpr(qualifier) match { + case UnresolvedAttribute(qualifierName) => + UnresolvedAttribute(qualifierName + "." + cleanIdentifier(attr)) + // The precidence for . seems to be wrong, so [] binds tighter an we need to go inside to + // find the underlying attribute references. + case GetItem(UnresolvedAttribute(qualifierName), ordinal) => + GetItem(UnresolvedAttribute(qualifierName + "." + cleanIdentifier(attr)), ordinal) + } + + /* Stars (*) */ + case Token("TOK_ALLCOLREF", Nil) => Star(None) + // The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only + // has a single child which is tableName. 
+ case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", Token(name, Nil) :: Nil) :: Nil) => + Star(Some(name)) + + /* Aggregate Functions */ + case Token("TOK_FUNCTION", Token(AVG(), Nil) :: arg :: Nil) => Average(nodeToExpr(arg)) + case Token("TOK_FUNCTION", Token(COUNT(), Nil) :: arg :: Nil) => Count(nodeToExpr(arg)) + case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => Count(Literal(1)) + case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => CountDistinct(args.map(nodeToExpr)) + case Token("TOK_FUNCTION", Token(SUM(), Nil) :: arg :: Nil) => Sum(nodeToExpr(arg)) + case Token("TOK_FUNCTIONDI", Token(SUM(), Nil) :: arg :: Nil) => SumDistinct(nodeToExpr(arg)) + + /* Casts */ + case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), StringType) + case Token("TOK_FUNCTION", Token("TOK_VARCHAR", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), StringType) + case Token("TOK_FUNCTION", Token("TOK_INT", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), IntegerType) + case Token("TOK_FUNCTION", Token("TOK_BIGINT", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), LongType) + case Token("TOK_FUNCTION", Token("TOK_FLOAT", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), FloatType) + case Token("TOK_FUNCTION", Token("TOK_DOUBLE", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), DoubleType) + case Token("TOK_FUNCTION", Token("TOK_SMALLINT", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), ShortType) + case Token("TOK_FUNCTION", Token("TOK_TINYINT", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), ByteType) + case Token("TOK_FUNCTION", Token("TOK_BINARY", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), BinaryType) + case Token("TOK_FUNCTION", Token("TOK_BOOLEAN", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), BooleanType) + case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), DecimalType) + + /* Arithmetic */ + case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child)) + case Token("+", left :: right:: Nil) => Add(nodeToExpr(left), nodeToExpr(right)) + case Token("-", left :: right:: Nil) => Subtract(nodeToExpr(left), nodeToExpr(right)) + case Token("*", left :: right:: Nil) => Multiply(nodeToExpr(left), nodeToExpr(right)) + case Token("/", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right)) + case Token("DIV", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right)) + case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right)) + + /* Comparisons */ + case Token("=", left :: right:: Nil) => Equals(nodeToExpr(left), nodeToExpr(right)) + case Token("!=", left :: right:: Nil) => Not(Equals(nodeToExpr(left), nodeToExpr(right))) + case Token("<>", left :: right:: Nil) => Not(Equals(nodeToExpr(left), nodeToExpr(right))) + case Token(">", left :: right:: Nil) => GreaterThan(nodeToExpr(left), nodeToExpr(right)) + case Token(">=", left :: right:: Nil) => GreaterThanOrEqual(nodeToExpr(left), nodeToExpr(right)) + case Token("<", left :: right:: Nil) => LessThan(nodeToExpr(left), nodeToExpr(right)) + case Token("<=", left :: right:: Nil) => LessThanOrEqual(nodeToExpr(left), nodeToExpr(right)) + case Token("LIKE", left :: right:: Nil) => + UnresolvedFunction("LIKE", Seq(nodeToExpr(left), nodeToExpr(right))) + case Token("RLIKE", left :: right:: Nil) => + UnresolvedFunction("RLIKE", Seq(nodeToExpr(left), nodeToExpr(right))) + case Token("REGEXP", left :: right:: Nil) => + UnresolvedFunction("REGEXP", Seq(nodeToExpr(left), nodeToExpr(right))) + case Token("TOK_FUNCTION", 
Token("TOK_ISNOTNULL", Nil) :: child :: Nil) => + IsNotNull(nodeToExpr(child)) + case Token("TOK_FUNCTION", Token("TOK_ISNULL", Nil) :: child :: Nil) => + IsNull(nodeToExpr(child)) + case Token("TOK_FUNCTION", Token("IN", Nil) :: value :: list) => + In(nodeToExpr(value), list.map(nodeToExpr)) + + /* Boolean Logic */ + case Token(AND(), left :: right:: Nil) => And(nodeToExpr(left), nodeToExpr(right)) + case Token(OR(), left :: right:: Nil) => Or(nodeToExpr(left), nodeToExpr(right)) + case Token(NOT(), child :: Nil) => Not(nodeToExpr(child)) + + /* Complex datatype manipulation */ + case Token("[", child :: ordinal :: Nil) => + GetItem(nodeToExpr(child), nodeToExpr(ordinal)) + + /* Other functions */ + case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand + + /* UDFs - Must be last otherwise will preempt built in functions */ + case Token("TOK_FUNCTION", Token(name, Nil) :: args) => + UnresolvedFunction(name, args.map(nodeToExpr)) + case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) => + UnresolvedFunction(name, Star(None) :: Nil) + + /* Literals */ + case Token("TOK_NULL", Nil) => Literal(null, NullType) + case Token(TRUE(), Nil) => Literal(true, BooleanType) + case Token(FALSE(), Nil) => Literal(false, BooleanType) + case Token("TOK_STRINGLITERALSEQUENCE", strings) => + Literal(strings.map(s => BaseSemanticAnalyzer.unescapeSQLString(s.getText)).mkString) + + // This code is adapted from + // /ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java#L223 + case ast: ASTNode if numericAstTypes contains ast.getType => + var v: Literal = null + try { + if (ast.getText.endsWith("L")) { + // Literal bigint. + v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toLong, LongType) + } else if (ast.getText.endsWith("S")) { + // Literal smallint. + v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toShort, ShortType) + } else if (ast.getText.endsWith("Y")) { + // Literal tinyint. 
+ v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toByte, ByteType)
+ } else if (ast.getText.endsWith("BD")) {
+ // Literal decimal
+ val strVal = ast.getText.substring(0, ast.getText.length() - 2)
+ v = Literal(BigDecimal(strVal), DecimalType)
+ } else {
+ // Try successively narrower types; a NumberFormatException from a narrower parse
+ // leaves the last successful (wider) literal in v.
+ v = Literal(ast.getText.toDouble, DoubleType)
+ v = Literal(ast.getText.toLong, LongType)
+ v = Literal(ast.getText.toInt, IntegerType)
+ }
+ } catch {
+ case nfe: NumberFormatException => // Do nothing
+ }
+
+ if (v == null) {
+ sys.error(s"Failed to parse number ${ast.getText}")
+ } else {
+ v
+ }
+
+ case ast: ASTNode if ast.getType == HiveParser.StringLiteral =>
+ Literal(BaseSemanticAnalyzer.unescapeSQLString(ast.getText))
+
+ case a: ASTNode =>
+ throw new NotImplementedError(
+ s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText}:
+ |${dumpTree(a).toString}
+ """.stripMargin)
+ }
+
+
+ val explode = "(?i)explode".r
+ def nodesToGenerator(nodes: Seq[Node]): Generator = {
+ val function = nodes.head
+
+ val attributes = nodes.flatMap {
+ case Token(a, Nil) => a.toLowerCase :: Nil
+ case _ => Nil
+ }
+
+ function match {
+ case Token("TOK_FUNCTION", Token(explode(), Nil) :: child :: Nil) =>
+ Explode(attributes, nodeToExpr(child))
+
+ case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) =>
+ HiveGenericUdtf(functionName, attributes, children.map(nodeToExpr))
+
+ case a: ASTNode =>
+ throw new NotImplementedError(
+ s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText}, tree:
+ |${dumpTree(a).toString}
+ """.stripMargin)
+ }
+ }
+
+ def dumpTree(node: Node, builder: StringBuilder = new StringBuilder, indent: Int = 0)
+ : StringBuilder = {
+ node match {
+ case a: ASTNode => builder.append((" " * indent) + a.getText + "\n")
+ case other => sys.error(s"Non ASTNode encountered: $other")
+ }
+
+ Option(node.getChildren).map(_.toList).getOrElse(Nil).foreach(dumpTree(_, builder, indent + 1))
+ builder
+ }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
new file mode 100644
index 0000000000..92d84208ab
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+
+import catalyst.expressions._
+import catalyst.planning._
+import catalyst.plans._
+import catalyst.plans.logical.{BaseRelation, LogicalPlan}
+
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.parquet.{ParquetRelation, InsertIntoParquetTable, ParquetTableScan}
+
+trait HiveStrategies {
+ // Possibly being too clever with types here... or not clever enough.
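+ // Concretely, the self-type annotation below means this trait can only be mixed into a
+ // SQLContext#SparkPlanner, which is what puts planLater() in scope for the strategies
+ // that follow.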
+ self: SQLContext#SparkPlanner => + + val hiveContext: HiveContext + + object Scripts extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.ScriptTransformation(input, script, output, child) => + ScriptTransformation(input, script, output, planLater(child))(hiveContext) :: Nil + case _ => Nil + } + } + + object DataSinks extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.InsertIntoTable(table: MetastoreRelation, partition, child, overwrite) => + InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil + case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) => + InsertIntoParquetTable(table, planLater(child))(hiveContext.sparkContext) :: Nil + case _ => Nil + } + } + + object HiveTableScans extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + // Push attributes into table scan when possible. + case p @ logical.Project(projectList, m: MetastoreRelation) if isSimpleProject(projectList) => + HiveTableScan(projectList.asInstanceOf[Seq[Attribute]], m, None)(hiveContext) :: Nil + case m: MetastoreRelation => + HiveTableScan(m.output, m, None)(hiveContext) :: Nil + case _ => Nil + } + } + + /** + * A strategy used to detect filtering predicates on top of a partitioned relation to help + * partition pruning. + * + * This strategy itself doesn't perform partition pruning, it just collects and combines all the + * partition pruning predicates and pass them down to the underlying [[HiveTableScan]] operator, + * which does the actual pruning work. + */ + object PartitionPrunings extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case p @ FilteredOperation(predicates, relation: MetastoreRelation) + if relation.isPartitioned => + + val partitionKeyIds = relation.partitionKeys.map(_.id).toSet + + // Filter out all predicates that only deal with partition keys + val (pruningPredicates, otherPredicates) = predicates.partition { + _.references.map(_.id).subsetOf(partitionKeyIds) + } + + val scan = HiveTableScan( + relation.output, relation, pruningPredicates.reduceLeftOption(And))(hiveContext) + + otherPredicates + .reduceLeftOption(And) + .map(Filter(_, scan)) + .getOrElse(scan) :: Nil + + case _ => + Nil + } + } + + /** + * A strategy that detects projects and filters over some relation and applies column pruning if + * possible. Partition pruning is applied first if the relation is partitioned. + */ + object ColumnPrunings extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + // TODO(andre): the current mix of HiveRelation and ParquetRelation + // here appears artificial; try to refactor to break it into two + case PhysicalOperation(projectList, predicates, relation: BaseRelation) => + val predicateOpt = predicates.reduceOption(And) + val predicateRefs = predicateOpt.map(_.references).getOrElse(Set.empty) + val projectRefs = projectList.flatMap(_.references) + + // To figure out what columns to preserve after column pruning, we need to consider: + // + // 1. Columns referenced by the project list (order preserved) + // 2. Columns referenced by filtering predicates but not by project list + // 3. 
Relation output + // + // Then the final result is ((1 union 2) intersect 3) + val prunedCols = (projectRefs ++ (predicateRefs -- projectRefs)).intersect(relation.output) + + val filteredScans = + if (relation.isPartitioned) { // from here on relation must be a [[MetaStoreRelation]] + // Applies partition pruning first for partitioned table + val filteredRelation = predicateOpt.map(logical.Filter(_, relation)).getOrElse(relation) + PartitionPrunings(filteredRelation).view.map(_.transform { + case scan: HiveTableScan => + scan.copy(attributes = prunedCols)(hiveContext) + }) + } else { + val scan = relation match { + case MetastoreRelation(_, _, _) => { + HiveTableScan( + prunedCols, + relation.asInstanceOf[MetastoreRelation], + None)(hiveContext) + } + case ParquetRelation(_, _) => { + ParquetTableScan( + relation.output, + relation.asInstanceOf[ParquetRelation], + None)(hiveContext.sparkContext) + .pruneColumns(prunedCols) + } + } + predicateOpt.map(execution.Filter(_, scan)).getOrElse(scan) :: Nil + } + + if (isSimpleProject(projectList) && prunedCols == projectRefs) { + filteredScans + } else { + filteredScans.view.map(execution.Project(projectList, _)) + } + + case _ => + Nil + } + } + + /** + * Returns true if `projectList` only performs column pruning and does not evaluate other + * complex expressions. + */ + def isSimpleProject(projectList: Seq[NamedExpression]) = { + projectList.forall(_.isInstanceOf[Attribute]) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ScriptTransformation.scala new file mode 100644 index 0000000000..f20e9d4de4 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ScriptTransformation.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package hive + +import java.io.{InputStreamReader, BufferedReader} + +import catalyst.expressions._ +import org.apache.spark.sql.execution._ + +import scala.collection.JavaConversions._ + +/** + * Transforms the input by forking and running the specified script. + * + * @param input the set of expression that should be passed to the script. + * @param script the command that should be executed. + * @param output the attributes that are produced by the script. 
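+ *
+ * For example (a sketch of intent, not a guaranteed plan shape), the HiveQL query
+ * {{{
+ *   SELECT TRANSFORM (key, value) USING 'cat' AS (tKey, tValue) FROM src
+ * }}}
+ * pipes tab-delimited renderings of (key, value) through `cat` and parses each line
+ * the script emits back into a (tKey, tValue) row.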
+ */ +case class ScriptTransformation( + input: Seq[Expression], + script: String, + output: Seq[Attribute], + child: SparkPlan)(@transient sc: HiveContext) + extends UnaryNode { + + override def otherCopyArgs = sc :: Nil + + def execute() = { + child.execute().mapPartitions { iter => + val cmd = List("/bin/bash", "-c", script) + val builder = new ProcessBuilder(cmd) + val proc = builder.start() + val inputStream = proc.getInputStream + val outputStream = proc.getOutputStream + val reader = new BufferedReader(new InputStreamReader(inputStream)) + + // TODO: This should be exposed as an iterator instead of reading in all the data at once. + val outputLines = collection.mutable.ArrayBuffer[Row]() + val readerThread = new Thread("Transform OutputReader") { + override def run() { + var curLine = reader.readLine() + while (curLine != null) { + // TODO: Use SerDe + outputLines += new GenericRow(curLine.split("\t").asInstanceOf[Array[Any]]) + curLine = reader.readLine() + } + } + } + readerThread.start() + val outputProjection = new Projection(input) + iter + .map(outputProjection) + // TODO: Use SerDe + .map(_.mkString("", "\t", "\n").getBytes).foreach(outputStream.write) + outputStream.close() + readerThread.join() + outputLines.toIterator + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala new file mode 100644 index 0000000000..71d751cbc4 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package hive + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants._ +import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} +import org.apache.hadoop.hive.ql.plan.TableDesc +import org.apache.hadoop.hive.serde2.Deserializer +import org.apache.hadoop.hive.ql.exec.Utilities +import org.apache.hadoop.io.Writable +import org.apache.hadoop.fs.{Path, PathFilter} +import org.apache.hadoop.mapred.{FileInputFormat, JobConf, InputFormat} + +import org.apache.spark.SerializableWritable +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.{HadoopRDD, UnionRDD, EmptyRDD, RDD} + + +/** + * A trait for subclasses that handle table scans. + */ +private[hive] sealed trait TableReader { + def makeRDDForTable(hiveTable: HiveTable): RDD[_] + + def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[_] + +} + + +/** + * Helper class for scanning tables stored in Hadoop - e.g., to read Hive tables that reside in the + * data warehouse directory. 
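+ *
+ * A hypothetical call site, for illustration only:
+ * {{{
+ *   val reader = new HadoopTableReader(relation.tableDesc, hiveContext)
+ *   val rows = if (hiveTable.isPartitioned) {
+ *     reader.makeRDDForPartitionedTable(partitions)
+ *   } else {
+ *     reader.makeRDDForTable(hiveTable)
+ *   }
+ * }}}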
+ */
+private[hive]
+class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveContext)
+ extends TableReader {
+
+ // Choose the minimum number of splits. If mapred.map.tasks is set, then use that unless
+ // it is smaller than what Spark suggests.
+ private val _minSplitsPerRDD = math.max(
+ sc.hiveconf.getInt("mapred.map.tasks", 1), sc.sparkContext.defaultMinSplits)
+
+
+ // TODO: set aws s3 credentials.
+
+ private val _broadcastedHiveConf =
+ sc.sparkContext.broadcast(new SerializableWritable(sc.hiveconf))
+
+ def broadcastedHiveConf = _broadcastedHiveConf
+
+ def hiveConf = _broadcastedHiveConf.value.value
+
+ override def makeRDDForTable(hiveTable: HiveTable): RDD[_] =
+ makeRDDForTable(
+ hiveTable,
+ _tableDesc.getDeserializerClass.asInstanceOf[Class[Deserializer]],
+ filterOpt = None)
+
+ /**
+ * Creates a Hadoop RDD to read data from the target table's data directory. Returns a
+ * transformed RDD that contains deserialized rows.
+ *
+ * @param hiveTable Hive metadata for the table being scanned.
+ * @param deserializerClass Class of the SerDe used to deserialize Writables read from Hadoop.
+ * @param filterOpt If defined, then the filter is used to reject files contained in the data
+ * directory being read. If None, then all files are accepted.
+ */
+ def makeRDDForTable(
+ hiveTable: HiveTable,
+ deserializerClass: Class[_ <: Deserializer],
+ filterOpt: Option[PathFilter]): RDD[_] =
+ {
+ assert(!hiveTable.isPartitioned, """makeRDDForTable() cannot be called on a partitioned table,
+ since input formats may differ across partitions. Use makeRDDForPartitionedTable() instead.""")
+
+ // Create local references to member variables, so that the entire `this` object won't be
+ // serialized in the closure below.
+ val tableDesc = _tableDesc
+ val broadcastedHiveConf = _broadcastedHiveConf
+
+ val tablePath = hiveTable.getPath
+ val inputPathStr = applyFilterIfNeeded(tablePath, filterOpt)
+
+ //logDebug("Table input: %s".format(tablePath))
+ val ifc = hiveTable.getInputFormatClass
+ .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]]
+ val hadoopRDD = createHadoopRdd(tableDesc, inputPathStr, ifc)
+
+ val deserializedHadoopRDD = hadoopRDD.mapPartitions { iter =>
+ val hconf = broadcastedHiveConf.value.value
+ val deserializer = deserializerClass.newInstance()
+ deserializer.initialize(hconf, tableDesc.getProperties)
+
+ // Deserialize each Writable to get the row value.
+ iter.map {
+ case v: Writable => deserializer.deserialize(v)
+ case value =>
+ sys.error(s"Unable to deserialize non-Writable: $value of ${value.getClass.getName}")
+ }
+ }
+ deserializedHadoopRDD
+ }
+
+ override def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[_] = {
+ val partitionToDeserializer = partitions.map(part =>
+ (part, part.getDeserializer.getClass.asInstanceOf[Class[Deserializer]])).toMap
+ makeRDDForPartitionedTable(partitionToDeserializer, filterOpt = None)
+ }
+
+ /**
+ * Create a HadoopRDD for every partition key specified in the query. Note that for on-disk Hive
+ * tables, a data directory is created for each partition corresponding to keys specified using
+ * 'PARTITIONED BY'.
+ *
+ * @param partitionToDeserializer Mapping from a Hive Partition metadata object to the SerDe
+ * class to use to deserialize input Writables from the corresponding partition.
+ * @param filterOpt If defined, then the filter is used to reject files contained in the data
+ * subdirectory of each partition being read. If None, then all files are accepted.
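+ *
+ * For instance, a table partitioned by (ds STRING, hr STRING) stores the data for
+ * partition (ds='2008-04-08', hr='11') under a subdirectory such as
+ * `.../tableName/ds=2008-04-08/hr=11` (illustrative path); one HadoopRDD is built per
+ * such directory and the results are unioned below.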
+ */ + def makeRDDForPartitionedTable( + partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]], + filterOpt: Option[PathFilter]): RDD[_] = + { + val hivePartitionRDDs = partitionToDeserializer.map { case (partition, partDeserializer) => + val partDesc = Utilities.getPartitionDesc(partition) + val partPath = partition.getPartitionPath + val inputPathStr = applyFilterIfNeeded(partPath, filterOpt) + val ifc = partDesc.getInputFileFormatClass + .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] + // Get partition field info + val partSpec = partDesc.getPartSpec + val partProps = partDesc.getProperties + + val partColsDelimited: String = partProps.getProperty(META_TABLE_PARTITION_COLUMNS) + // Partitioning columns are delimited by "/" + val partCols = partColsDelimited.trim().split("/").toSeq + // 'partValues[i]' contains the value for the partitioning column at 'partCols[i]'. + val partValues = if (partSpec == null) { + Array.fill(partCols.size)(new String) + } else { + partCols.map(col => new String(partSpec.get(col))).toArray + } + + // Create local references so that the outer object isn't serialized. + val tableDesc = _tableDesc + val broadcastedHiveConf = _broadcastedHiveConf + val localDeserializer = partDeserializer + + val hivePartitionRDD = createHadoopRdd(tableDesc, inputPathStr, ifc) + hivePartitionRDD.mapPartitions { iter => + val hconf = broadcastedHiveConf.value.value + val rowWithPartArr = new Array[Object](2) + // Map each tuple to a row object + iter.map { value => + val deserializer = localDeserializer.newInstance() + deserializer.initialize(hconf, partProps) + val deserializedRow = deserializer.deserialize(value) + rowWithPartArr.update(0, deserializedRow) + rowWithPartArr.update(1, partValues) + rowWithPartArr.asInstanceOf[Object] + } + } + }.toSeq + // Even if we don't use any partitions, we still need an empty RDD + if (hivePartitionRDDs.size == 0) { + new EmptyRDD[Object](sc.sparkContext) + } else { + new UnionRDD(hivePartitionRDDs(0).context, hivePartitionRDDs) + } + } + + /** + * If `filterOpt` is defined, then it will be used to filter files from `path`. These files are + * returned in a single, comma-separated string. + */ + private def applyFilterIfNeeded(path: Path, filterOpt: Option[PathFilter]): String = { + filterOpt match { + case Some(filter) => + val fs = path.getFileSystem(sc.hiveconf) + val filteredFiles = fs.listStatus(path, filter).map(_.getPath.toString) + filteredFiles.mkString(",") + case None => path.toString + } + } + + /** + * Creates a HadoopRDD based on the broadcasted HiveConf and other job properties that will be + * applied locally on each slave. + */ + private def createHadoopRdd( + tableDesc: TableDesc, + path: String, + inputFormatClass: Class[InputFormat[Writable, Writable]]) + : RDD[Writable] = { + val initializeJobConfFunc = HadoopTableReader.initializeLocalJobConfFunc(path, tableDesc) _ + + val rdd = new HadoopRDD( + sc.sparkContext, + _broadcastedHiveConf.asInstanceOf[Broadcast[SerializableWritable[Configuration]]], + Some(initializeJobConfFunc), + inputFormatClass, + classOf[Writable], + classOf[Writable], + _minSplitsPerRDD) + + // Only take the value (skip the key) because Hive works only with values. + rdd.map(_._2) + } + +} + +private[hive] object HadoopTableReader { + + /** + * Curried. After given an argument for 'path', the resulting JobConf => Unit closure is used to + * instantiate a HadoopRDD. 
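+ *
+ * E.g., mirroring the use in createHadoopRdd above:
+ * {{{
+ *   val initF = HadoopTableReader.initializeLocalJobConfFunc(path, tableDesc) _
+ *   // initF: JobConf => Unit, applied to the JobConf on each executor
+ * }}}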
+ */ + def initializeLocalJobConfFunc(path: String, tableDesc: TableDesc)(jobConf: JobConf) { + FileInputFormat.setInputPaths(jobConf, path) + if (tableDesc != null) { + Utilities.copyTableJobPropertiesToConf(tableDesc, jobConf) + } + val bufferSize = System.getProperty("spark.buffer.size", "65536") + jobConf.set("io.file.buffer.size", bufferSize) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala new file mode 100644 index 0000000000..17ae4ef63c --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -0,0 +1,341 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package hive + +import java.io.File +import java.util.{Set => JavaSet} + +import scala.collection.mutable +import scala.collection.JavaConversions._ +import scala.language.implicitConversions + +import org.apache.hadoop.hive.metastore.api.{SerDeInfo, StorageDescriptor} +import org.apache.hadoop.hive.metastore.MetaStoreUtils +import org.apache.hadoop.hive.ql.exec.FunctionRegistry +import org.apache.hadoop.hive.ql.io.avro.{AvroContainerOutputFormat, AvroContainerInputFormat} +import org.apache.hadoop.hive.ql.metadata.Table +import org.apache.hadoop.hive.serde2.avro.AvroSerDe +import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe +import org.apache.hadoop.hive.serde2.RegexSerDe + +import org.apache.spark.{SparkContext, SparkConf} + +import catalyst.analysis._ +import catalyst.plans.logical.{LogicalPlan, NativeCommand} +import catalyst.util._ + +object TestHive + extends TestHiveContext(new SparkContext("local", "TestSQLContext", new SparkConf())) + +/** + * A locally running test instance of Spark's Hive execution engine. + * + * Data from [[testTables]] will be automatically loaded whenever a query is run over those tables. + * Calling [[reset]] will delete all tables and other state in the database, leaving the database + * in a "clean" state. + * + * TestHive is singleton object version of this class because instantiating multiple copies of the + * hive metastore seems to lead to weird non-deterministic failures. Therefore, the execution of + * testcases that rely on TestHive must be serialized. + */ +class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) { + self => + + // By clearing the port we force Spark to pick a new one. This allows us to rerun tests + // without restarting the JVM. 
+ System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") + + override lazy val warehousePath = getTempFilePath("sparkHiveWarehouse").getCanonicalPath + override lazy val metastorePath = getTempFilePath("sparkHiveMetastore").getCanonicalPath + + /** The location of the compiled hive distribution */ + lazy val hiveHome = envVarToFile("HIVE_HOME") + /** The location of the hive source code. */ + lazy val hiveDevHome = envVarToFile("HIVE_DEV_HOME") + + // Override so we can intercept relative paths and rewrite them to point at hive. + override def runSqlHive(sql: String): Seq[String] = super.runSqlHive(rewritePaths(sql)) + + override def executePlan(plan: LogicalPlan): this.QueryExecution = + new this.QueryExecution { val logical = plan } + + /** + * Returns the value of specified environmental variable as a [[java.io.File]] after checking + * to ensure it exists + */ + private def envVarToFile(envVar: String): Option[File] = { + Option(System.getenv(envVar)).map(new File(_)) + } + + /** + * Replaces relative paths to the parent directory "../" with hiveDevHome since this is how the + * hive test cases assume the system is set up. + */ + private def rewritePaths(cmd: String): String = + if (cmd.toUpperCase contains "LOAD DATA") { + val testDataLocation = + hiveDevHome.map(_.getCanonicalPath).getOrElse(inRepoTests.getCanonicalPath) + cmd.replaceAll("\\.\\.", testDataLocation) + } else { + cmd + } + + val hiveFilesTemp = File.createTempFile("catalystHiveFiles", "") + hiveFilesTemp.delete() + hiveFilesTemp.mkdir() + + val inRepoTests = new File("src/test/resources/") + def getHiveFile(path: String): File = { + val stripped = path.replaceAll("""\.\.\/""", "") + hiveDevHome + .map(new File(_, stripped)) + .filter(_.exists) + .getOrElse(new File(inRepoTests, stripped)) + } + + val describedTable = "DESCRIBE (\\w+)".r + + class SqlQueryExecution(sql: String) extends this.QueryExecution { + lazy val logical = HiveQl.parseSql(sql) + def hiveExec() = runSqlHive(sql) + override def toString = sql + "\n" + super.toString + } + + /** + * Override QueryExecution with special debug workflow. + */ + abstract class QueryExecution extends super.QueryExecution { + override lazy val analyzed = { + val describedTables = logical match { + case NativeCommand(describedTable(tbl)) => tbl :: Nil + case _ => Nil + } + + // Make sure any test tables referenced are loaded. + val referencedTables = + describedTables ++ + logical.collect { case UnresolvedRelation(databaseName, name, _) => name } + val referencedTestTables = referencedTables.filter(testTables.contains) + logger.debug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") + referencedTestTables.foreach(loadTestTable) + // Proceed with analysis. + analyzer(logical) + } + } + + case class TestTable(name: String, commands: (()=>Unit)*) + + implicit class SqlCmd(sql: String) { + def cmd = () => new SqlQueryExecution(sql).stringResult(): Unit + } + + /** + * A list of test tables and the DDL required to initialize them. A test table is loaded on + * demand when a query are run against it. + */ + lazy val testTables = new mutable.HashMap[String, TestTable]() + def registerTestTable(testTable: TestTable) = testTables += (testTable.name -> testTable) + + // The test tables that are defined in the Hive QTestUtil. 
+ // /itests/util/src/main/java/org/apache/hadoop/hive/ql/QTestUtil.java + val hiveQTestUtilTables = Seq( + TestTable("src", + "CREATE TABLE src (key INT, value STRING)".cmd, + s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' INTO TABLE src".cmd), + TestTable("src1", + "CREATE TABLE src1 (key INT, value STRING)".cmd, + s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd), + TestTable("dest1", + "CREATE TABLE IF NOT EXISTS dest1 (key INT, value STRING)".cmd), + TestTable("dest2", + "CREATE TABLE IF NOT EXISTS dest2 (key INT, value STRING)".cmd), + TestTable("dest3", + "CREATE TABLE IF NOT EXISTS dest3 (key INT, value STRING)".cmd), + TestTable("srcpart", () => { + runSqlHive( + "CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)") + for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) { + runSqlHive( + s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' + |OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr') + """.stripMargin) + } + }), + TestTable("srcpart1", () => { + runSqlHive("CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)") + for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- 11 to 12) { + runSqlHive( + s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' + |OVERWRITE INTO TABLE srcpart1 PARTITION (ds='$ds',hr='$hr') + """.stripMargin) + } + }), + TestTable("src_thrift", () => { + import org.apache.thrift.protocol.TBinaryProtocol + import org.apache.hadoop.hive.serde2.thrift.test.Complex + import org.apache.hadoop.hive.serde2.thrift.ThriftDeserializer + import org.apache.hadoop.mapred.SequenceFileInputFormat + import org.apache.hadoop.mapred.SequenceFileOutputFormat + + val srcThrift = new Table("default", "src_thrift") + srcThrift.setFields(Nil) + srcThrift.setInputFormatClass(classOf[SequenceFileInputFormat[_,_]].getName) + // In Hive, SequenceFileOutputFormat will be substituted by HiveSequenceFileOutputFormat. 
+ srcThrift.setOutputFormatClass(classOf[SequenceFileOutputFormat[_,_]].getName) + srcThrift.setSerializationLib(classOf[ThriftDeserializer].getName) + srcThrift.setSerdeParam("serialization.class", classOf[Complex].getName) + srcThrift.setSerdeParam("serialization.format", classOf[TBinaryProtocol].getName) + catalog.client.createTable(srcThrift) + + + runSqlHive( + s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/complex.seq")}' INTO TABLE src_thrift") + }), + TestTable("serdeins", + s"""CREATE TABLE serdeins (key INT, value STRING) + |ROW FORMAT SERDE '${classOf[LazySimpleSerDe].getCanonicalName}' + |WITH SERDEPROPERTIES ('field.delim'='\\t') + """.stripMargin.cmd, + "INSERT OVERWRITE TABLE serdeins SELECT * FROM src".cmd), + TestTable("sales", + s"""CREATE TABLE IF NOT EXISTS sales (key STRING, value INT) + |ROW FORMAT SERDE '${classOf[RegexSerDe].getCanonicalName}' + |WITH SERDEPROPERTIES ("input.regex" = "([^ ]*)\t([^ ]*)") + """.stripMargin.cmd, + s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/sales.txt")}' INTO TABLE sales".cmd), + TestTable("episodes", + s"""CREATE TABLE episodes (title STRING, air_date STRING, doctor INT) + |ROW FORMAT SERDE '${classOf[AvroSerDe].getCanonicalName}' + |STORED AS + |INPUTFORMAT '${classOf[AvroContainerInputFormat].getCanonicalName}' + |OUTPUTFORMAT '${classOf[AvroContainerOutputFormat].getCanonicalName}' + |TBLPROPERTIES ( + | 'avro.schema.literal'='{ + | "type": "record", + | "name": "episodes", + | "namespace": "testing.hive.avro.serde", + | "fields": [ + | { + | "name": "title", + | "type": "string", + | "doc": "episode title" + | }, + | { + | "name": "air_date", + | "type": "string", + | "doc": "initial date" + | }, + | { + | "name": "doctor", + | "type": "int", + | "doc": "main actor playing the Doctor in episode" + | } + | ] + | }' + |) + """.stripMargin.cmd, + s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/episodes.avro")}' INTO TABLE episodes".cmd + ) + ) + + hiveQTestUtilTables.foreach(registerTestTable) + + private val loadedTables = new collection.mutable.HashSet[String] + + def loadTestTable(name: String) { + if (!(loadedTables contains name)) { + // Marks the table as loaded first to prevent infite mutually recursive table loading. + loadedTables += name + logger.info(s"Loading test table $name") + val createCmds = + testTables.get(name).map(_.commands).getOrElse(sys.error(s"Unknown test table $name")) + createCmds.foreach(_()) + } + } + + /** + * Records the UDFs present when the server starts, so we can delete ones that are created by + * tests. + */ + protected val originalUdfs: JavaSet[String] = FunctionRegistry.getFunctionNames + + /** + * Resets the test instance by deleting any tables that have been created. + * TODO: also clear out UDFs, views, etc. + */ + def reset() { + try { + // HACK: Hive is too noisy by default. + org.apache.log4j.LogManager.getCurrentLoggers.foreach { logger => + logger.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) + } + + // It is important that we RESET first as broken hooks that might have been set could break + // other sql exec here. + runSqlHive("RESET") + // For some reason, RESET does not reset the following variables... + runSqlHive("set datanucleus.cache.collections=true") + runSqlHive("set datanucleus.cache.collections.lazy=true") + // Lots of tests fail if we do not change the partition whitelist from the default. 
+ runSqlHive("set hive.metastore.partition.name.whitelist.pattern=.*") + + loadedTables.clear() + catalog.client.getAllTables("default").foreach { t => + logger.debug(s"Deleting table $t") + val table = catalog.client.getTable("default", t) + + catalog.client.getIndexes("default", t, 255).foreach { index => + catalog.client.dropIndex("default", t, index.getIndexName, true) + } + + if (!table.isIndexTable) { + catalog.client.dropTable("default", t) + } + } + + catalog.client.getAllDatabases.filterNot(_ == "default").foreach { db => + logger.debug(s"Dropping Database: $db") + catalog.client.dropDatabase(db, true, false, true) + } + + FunctionRegistry.getFunctionNames.filterNot(originalUdfs.contains(_)).foreach { udfName => + FunctionRegistry.unregisterTemporaryUDF(udfName) + } + + configure() + + runSqlHive("USE default") + + // Just loading src makes a lot of tests pass. This is because some tests do something like + // drop an index on src at the beginning. Since we just pass DDL to hive this bypasses our + // Analyzer and thus the test table auto-loading mechanism. + // Remove after we handle more DDL operations natively. + loadTestTable("src") + loadTestTable("srcpart") + } catch { + case e: Exception => + logger.error(s"FATAL ERROR: Failed to reset TestDB state. $e") + // At this point there is really no reason to continue, but the test framework traps exits. + // So instead we just pause forever so that at least the developer can see where things + // started to go wrong. + Thread.sleep(100000) + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala new file mode 100644 index 0000000000..d20fd87f34 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala @@ -0,0 +1,356 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql
+package hive
+
+import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar}
+import org.apache.hadoop.hive.metastore.MetaStoreUtils
+import org.apache.hadoop.hive.ql.Context
+import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Hive}
+import org.apache.hadoop.hive.ql.plan.{TableDesc, FileSinkDesc}
+import org.apache.hadoop.hive.serde2.Serializer
+import org.apache.hadoop.hive.serde2.objectinspector._
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveDecimalObjectInspector
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveVarcharObjectInspector
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.Writable
+import org.apache.hadoop.mapred._
+
+import catalyst.expressions._
+import catalyst.types.{BooleanType, DataType}
+import org.apache.spark.{TaskContext, SparkException}
+import catalyst.expressions.Cast
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.execution._
+
+import scala.Some
+import scala.collection.immutable.ListMap
+
+/* Implicits */
+import scala.collection.JavaConversions._
+
+/**
+ * The Hive table scan operator. Column and partition pruning are both handled.
+ *
+ * @constructor
+ * @param attributes Attributes to be fetched from the Hive table.
+ * @param relation The Hive table to be scanned.
+ * @param partitionPruningPred An optional partition pruning predicate for partitioned tables.
+ */
+case class HiveTableScan(
+ attributes: Seq[Attribute],
+ relation: MetastoreRelation,
+ partitionPruningPred: Option[Expression])(
+ @transient val sc: HiveContext)
+ extends LeafNode
+ with HiveInspectors {
+
+ require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned,
+ "Partition pruning predicates only supported for partitioned tables.")
+
+ // Bind all partition key attribute references in the partition pruning predicate for later
+ // evaluation.
+ private val boundPruningPred = partitionPruningPred.map { pred =>
+ require(
+ pred.dataType == BooleanType,
+ s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.")
+
+ BindReferences.bindReference(pred, relation.partitionKeys)
+ }
+
+ @transient
+ val hadoopReader = new HadoopTableReader(relation.tableDesc, sc)
+
+ /**
+ * The hive object inspector for this table, which can be used to extract values from the
+ * serialized row representation.
+ */
+ @transient
+ lazy val objectInspector =
+ relation.tableDesc.getDeserializer.getObjectInspector.asInstanceOf[StructObjectInspector]
+
+ /**
+ * Functions that extract the requested attributes from the hive output. Partition values are
+ * cast from string to their declared data types.
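+ *
+ * For a table partitioned by (ds STRING), for example, the function generated for the
+ * `ds` attribute ignores the deserialized row and reads slot 0 of the partition-key
+ * array, while ordinary columns are extracted through the table's object inspector
+ * (a simplified description of the code below).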
+ */
+ @transient
+ protected lazy val attributeFunctions: Seq[(Any, Array[String]) => Any] = {
+ attributes.map { a =>
+ val ordinal = relation.partitionKeys.indexOf(a)
+ if (ordinal >= 0) {
+ (_: Any, partitionKeys: Array[String]) => {
+ val value = partitionKeys(ordinal)
+ val dataType = relation.partitionKeys(ordinal).dataType
+ castFromString(value, dataType)
+ }
+ } else {
+ val ref = objectInspector.getAllStructFieldRefs
+ .find(_.getFieldName == a.name)
+ .getOrElse(sys.error(s"Can't find attribute $a"))
+ (row: Any, _: Array[String]) => {
+ val data = objectInspector.getStructFieldData(row, ref)
+ unwrapData(data, ref.getFieldObjectInspector)
+ }
+ }
+ }
+ }
+
+ private def castFromString(value: String, dataType: DataType) = {
+ Cast(Literal(value), dataType).apply(null)
+ }
+
+ @transient
+ def inputRdd = if (!relation.hiveQlTable.isPartitioned) {
+ hadoopReader.makeRDDForTable(relation.hiveQlTable)
+ } else {
+ hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions))
+ }
+
+ /**
+ * Prunes partitions not involved in the query plan.
+ *
+ * @param partitions All partitions of the relation.
+ * @return Partitions that are involved in the query plan.
+ */
+ private[hive] def prunePartitions(partitions: Seq[HivePartition]) = {
+ boundPruningPred match {
+ case None => partitions
+ case Some(shouldKeep) => partitions.filter { part =>
+ val dataTypes = relation.partitionKeys.map(_.dataType)
+ val castedValues = for ((value, dataType) <- part.getValues.zip(dataTypes)) yield {
+ castFromString(value, dataType)
+ }
+
+ // Only partition values are needed here, since the predicate has already been bound to
+ // partition key attribute references.
+ val row = new GenericRow(castedValues.toArray)
+ shouldKeep.apply(row).asInstanceOf[Boolean]
+ }
+ }
+ }
+
+ def execute() = {
+ inputRdd.map { row =>
+ val values = row match {
+ case Array(deserializedRow: AnyRef, partitionKeys: Array[String]) =>
+ attributeFunctions.map(_(deserializedRow, partitionKeys))
+ case deserializedRow: AnyRef =>
+ attributeFunctions.map(_(deserializedRow, Array.empty))
+ }
+ buildRow(values.map {
+ case n: String if n.toLowerCase == "null" => null
+ case varchar: org.apache.hadoop.hive.common.`type`.HiveVarchar => varchar.getValue
+ case decimal: org.apache.hadoop.hive.common.`type`.HiveDecimal =>
+ BigDecimal(decimal.bigDecimalValue)
+ case other => other
+ })
+ }
+ }
+
+ def output = attributes
+}
+
+case class InsertIntoHiveTable(
+ table: MetastoreRelation,
+ partition: Map[String, Option[String]],
+ child: SparkPlan,
+ overwrite: Boolean)
+ (@transient sc: HiveContext)
+ extends UnaryNode {
+
+ val outputClass = newSerializer(table.tableDesc).getSerializedClass
+ @transient private val hiveContext = new Context(sc.hiveconf)
+ @transient private val db = Hive.get(sc.hiveconf)
+
+ private def newSerializer(tableDesc: TableDesc): Serializer = {
+ val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer]
+ serializer.initialize(null, tableDesc.getProperties)
+ serializer
+ }
+
+ override def otherCopyArgs = sc :: Nil
+
+ def output = child.output
+
+ /**
+ * Wraps with Hive types based on object inspector.
+ * TODO: Consolidate all hive OI/data interface code.
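+ *
+ * Illustrative behavior, assuming matching primitive inspectors (hypothetical names):
+ * {{{
+ *   wrap(("ab", varcharInspector))            // => new HiveVarchar("ab", 2)
+ *   wrap((BigDecimal(1.5), decimalInspector)) // => new HiveDecimal(1.5)
+ *   wrap((42, intInspector))                  // => 42, passed through unchanged
+ * }}}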
+ */ + protected def wrap(a: (Any, ObjectInspector)): Any = a match { + case (s: String, oi: JavaHiveVarcharObjectInspector) => new HiveVarchar(s, s.size) + case (bd: BigDecimal, oi: JavaHiveDecimalObjectInspector) => + new HiveDecimal(bd.underlying()) + case (row: Row, oi: StandardStructObjectInspector) => + val struct = oi.create() + row.zip(oi.getAllStructFieldRefs).foreach { + case (data, field) => + oi.setStructFieldData(struct, field, wrap(data, field.getFieldObjectInspector)) + } + struct + case (s: Seq[_], oi: ListObjectInspector) => + val wrappedSeq = s.map(wrap(_, oi.getListElementObjectInspector)) + seqAsJavaList(wrappedSeq) + case (obj, _) => obj + } + + def saveAsHiveFile( + rdd: RDD[Writable], + valueClass: Class[_], + fileSinkConf: FileSinkDesc, + conf: JobConf, + isCompressed: Boolean) { + if (valueClass == null) { + throw new SparkException("Output value class not set") + } + conf.setOutputValueClass(valueClass) + if (fileSinkConf.getTableInfo.getOutputFileFormatClassName == null) { + throw new SparkException("Output format class not set") + } + // Doesn't work in Scala 2.9 due to what may be a generics bug + // TODO: Should we uncomment this for Scala 2.10? + // conf.setOutputFormat(outputFormatClass) + conf.set("mapred.output.format.class", fileSinkConf.getTableInfo.getOutputFileFormatClassName) + if (isCompressed) { + // Please note that isCompressed, "mapred.output.compress", "mapred.output.compression.codec", + // and "mapred.output.compression.type" have no impact on ORC because it uses table properties + // to store compression information. + conf.set("mapred.output.compress", "true") + fileSinkConf.setCompressed(true) + fileSinkConf.setCompressCodec(conf.get("mapred.output.compression.codec")) + fileSinkConf.setCompressType(conf.get("mapred.output.compression.type")) + } + conf.setOutputCommitter(classOf[FileOutputCommitter]) + FileOutputFormat.setOutputPath( + conf, + SparkHiveHadoopWriter.createPathFromString(fileSinkConf.getDirName, conf)) + + logger.debug("Saving as hadoop file of type " + valueClass.getSimpleName) + + val writer = new SparkHiveHadoopWriter(conf, fileSinkConf) + writer.preSetup() + + def writeToFile(context: TaskContext, iter: Iterator[Writable]) { + // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it + // around by taking a mod. We expect that no task will be attempted 2 billion times. + val attemptNumber = (context.attemptId % Int.MaxValue).toInt + + writer.setup(context.stageId, context.partitionId, attemptNumber) + writer.open() + + var count = 0 + while(iter.hasNext) { + val record = iter.next() + count += 1 + writer.write(record) + } + + writer.close() + writer.commit() + } + + sc.sparkContext.runJob(rdd, writeToFile _) + writer.commitJob() + } + + /** + * Inserts all the rows in the table into Hive. Row objects are properly serialized with the + * `org.apache.hadoop.hive.serde2.SerDe` and the + * `org.apache.hadoop.mapred.OutputFormat` provided by the table definition. + */ + def execute() = { + val childRdd = child.execute() + assert(childRdd != null) + + // Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer + // instances within the closure, since Serializer is not serializable while TableDesc is. 
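+ // (Same closure-capture discipline as in HadoopTableReader: binding the TableDesc to a
+ // local val below keeps the non-serializable parts of `this` out of the shipped closure.)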
+ val tableDesc = table.tableDesc
+ val tableLocation = table.hiveQlTable.getDataLocation
+ val tmpLocation = hiveContext.getExternalTmpFileURI(tableLocation)
+ val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false)
+ val rdd = childRdd.mapPartitions { iter =>
+ val serializer = newSerializer(fileSinkConf.getTableInfo)
+ val standardOI = ObjectInspectorUtils
+ .getStandardObjectInspector(
+ fileSinkConf.getTableInfo.getDeserializer.getObjectInspector,
+ ObjectInspectorCopyOption.JAVA)
+ .asInstanceOf[StructObjectInspector]
+
+ iter.map { row =>
+ // Casts Strings to HiveVarchars when necessary.
+ val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector)
+ val mappedRow = row.zip(fieldOIs).map(wrap)
+
+ serializer.serialize(mappedRow.toArray, standardOI)
+ }
+ }
+
+ // ORC stores compression information in its table properties, while other formats
+ // (e.g. RCFile) rely on hadoop configurations to store compression information.
+ val jobConf = new JobConf(sc.hiveconf)
+ saveAsHiveFile(
+ rdd,
+ outputClass,
+ fileSinkConf,
+ jobConf,
+ sc.hiveconf.getBoolean("hive.exec.compress.output", false))
+
+ // TODO: Handle dynamic partitioning.
+ val outputPath = FileOutputFormat.getOutputPath(jobConf)
+ // Have to construct the format of dbname.tablename.
+ val qualifiedTableName = s"${table.databaseName}.${table.tableName}"
+ // TODO: Correctly set holdDDLTime.
+ // Most of the time holdDDLTime should be false; it will be true when TOK_HOLD_DDLTIME
+ // is present in the query as a hint.
+ val holdDDLTime = false
+ if (partition.nonEmpty) {
+ val partitionSpec = partition.map {
+ case (key, Some(value)) => key -> value
+ case (key, None) => key -> "" // Should not reach here right now.
+ }
+ val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols(), partitionSpec)
+ db.validatePartitionNameCharacters(partVals)
+ // inheritTableSpecs is set to true. It should be set to false for an IMPORT query,
+ // which is currently considered a Hive native command.
+ val inheritTableSpecs = true
+ // TODO: Correctly set isSkewedStoreAsSubdir.
+ val isSkewedStoreAsSubdir = false
+ db.loadPartition(
+ outputPath,
+ qualifiedTableName,
+ partitionSpec,
+ overwrite,
+ holdDDLTime,
+ inheritTableSpecs,
+ isSkewedStoreAsSubdir)
+ } else {
+ db.loadTable(
+ outputPath,
+ qualifiedTableName,
+ overwrite,
+ holdDDLTime)
+ }
+
+ // It would be nice to just return the childRdd unchanged so insert operations could be
+ // chained; however, for now we return an empty list to simplify compatibility checks with
+ // Hive, which does not return anything for insert operations.
+ // TODO: implement hive compatibility as rules.
+ sc.sparkContext.makeRDD(Nil, 1)
+ }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
new file mode 100644
index 0000000000..5e775d6a04
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -0,0 +1,467 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package hive + +import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.hadoop.hive.common.`type`.HiveDecimal +import org.apache.hadoop.hive.serde2.{io => hiveIo} +import org.apache.hadoop.hive.serde2.objectinspector.primitive._ +import org.apache.hadoop.hive.serde2.objectinspector._ +import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} +import org.apache.hadoop.hive.ql.udf.generic._ +import org.apache.hadoop.hive.ql.exec.UDF +import org.apache.hadoop.{io => hadoopIo} + +import catalyst.analysis +import catalyst.expressions._ +import catalyst.types +import catalyst.types._ + +object HiveFunctionRegistry + extends analysis.FunctionRegistry with HiveFunctionFactory with HiveInspectors { + + def lookupFunction(name: String, children: Seq[Expression]): Expression = { + // We only look it up to see if it exists, but do not include it in the HiveUDF since it is + // not always serializable. + val functionInfo: FunctionInfo = Option(FunctionRegistry.getFunctionInfo(name)).getOrElse( + sys.error(s"Couldn't find function $name")) + + if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { + val function = createFunction[UDF](name) + val method = function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo)) + + lazy val expectedDataTypes = method.getParameterTypes.map(javaClassToDataType) + + HiveSimpleUdf( + name, + children.zip(expectedDataTypes).map { case (e, t) => Cast(e, t) } + ) + } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveGenericUdf(name, children) + } else if ( + classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveGenericUdaf(name, children) + + } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveGenericUdtf(name, Nil, children) + } else { + sys.error(s"No handler for udf ${functionInfo.getFunctionClass}") + } + } + + def javaClassToDataType(clz: Class[_]): DataType = clz match { + case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType + case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType + case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType + case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType + case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType + case c: Class[_] if c == classOf[hadoopIo.Text] => StringType + case c: Class[_] if c == classOf[hadoopIo.IntWritable] => IntegerType + case c: Class[_] if c == classOf[hadoopIo.LongWritable] => LongType + case c: Class[_] if c == classOf[hadoopIo.FloatWritable] => FloatType + case c: Class[_] if c == classOf[hadoopIo.BooleanWritable] => BooleanType + case c: Class[_] if c == classOf[java.lang.String] => StringType + case c: Class[_] if c == java.lang.Short.TYPE => ShortType + case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType + case c: Class[_] if c == java.lang.Long.TYPE => LongType + case c: Class[_] if c == java.lang.Double.TYPE => DoubleType + case c: Class[_] if c == 
java.lang.Byte.TYPE => ByteType + case c: Class[_] if c == java.lang.Float.TYPE => FloatType + case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType + case c: Class[_] if c == classOf[java.lang.Short] => ShortType + case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType + case c: Class[_] if c == classOf[java.lang.Long] => LongType + case c: Class[_] if c == classOf[java.lang.Double] => DoubleType + case c: Class[_] if c == classOf[java.lang.Byte] => ByteType + case c: Class[_] if c == classOf[java.lang.Float] => FloatType + case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType + case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType)) + } +} + +trait HiveFunctionFactory { + def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name) + def getFunctionClass(name: String) = getFunctionInfo(name).getFunctionClass + def createFunction[UDFType](name: String) = + getFunctionClass(name).newInstance.asInstanceOf[UDFType] + + /** Converts hive types to native catalyst types. */ + def unwrap(a: Any): Any = a match { + case null => null + case i: hadoopIo.IntWritable => i.get + case t: hadoopIo.Text => t.toString + case l: hadoopIo.LongWritable => l.get + case d: hadoopIo.DoubleWritable => d.get() + case d: hiveIo.DoubleWritable => d.get + case s: hiveIo.ShortWritable => s.get + case b: hadoopIo.BooleanWritable => b.get() + case b: hiveIo.ByteWritable => b.get + case list: java.util.List[_] => list.map(unwrap) + case map: java.util.Map[_,_] => map.map { case (k, v) => (unwrap(k), unwrap(v)) }.toMap + case array: Array[_] => array.map(unwrap).toSeq + case p: java.lang.Short => p + case p: java.lang.Long => p + case p: java.lang.Float => p + case p: java.lang.Integer => p + case p: java.lang.Double => p + case p: java.lang.Byte => p + case p: java.lang.Boolean => p + case str: String => str + } +} + +abstract class HiveUdf + extends Expression with Logging with HiveFunctionFactory { + self: Product => + + type UDFType + type EvaluatedType = Any + + val name: String + + def nullable = true + def references = children.flatMap(_.references).toSet + + // FunctionInfo is not serializable so we must look it up here again. 
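+ // (Both lookups below are lazy, so a copy of this expression deserialized on an executor
+ // re-resolves the function from just its name.)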
+ lazy val functionInfo = getFunctionInfo(name) + lazy val function = createFunction[UDFType](name) + + override def toString = s"${nodeName}#${functionInfo.getDisplayName}(${children.mkString(",")})" +} + +case class HiveSimpleUdf(name: String, children: Seq[Expression]) extends HiveUdf { + import HiveFunctionRegistry._ + type UDFType = UDF + + @transient + protected lazy val method = + function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo)) + + @transient + lazy val dataType = javaClassToDataType(method.getReturnType) + + protected lazy val wrappers: Array[(Any) => AnyRef] = method.getParameterTypes.map { argClass => + val primitiveClasses = Seq( + Integer.TYPE, classOf[java.lang.Integer], classOf[java.lang.String], java.lang.Double.TYPE, + classOf[java.lang.Double], java.lang.Long.TYPE, classOf[java.lang.Long], + classOf[HiveDecimal], java.lang.Byte.TYPE, classOf[java.lang.Byte] + ) + val matchingConstructor = argClass.getConstructors.find { c => + c.getParameterTypes.size == 1 && primitiveClasses.contains(c.getParameterTypes.head) + } + + val constructor = matchingConstructor.getOrElse( + sys.error(s"No matching wrapper found, options: ${argClass.getConstructors.toSeq}.")) + + (a: Any) => { + logger.debug( + s"Wrapping $a of type ${if (a == null) "null" else a.getClass.getName} using $constructor.") + // We must make sure that primitives get boxed java style. + if (a == null) { + null + } else { + constructor.newInstance(a match { + case i: Int => i: java.lang.Integer + case bd: BigDecimal => new HiveDecimal(bd.underlying()) + case other: AnyRef => other + }).asInstanceOf[AnyRef] + } + } + } + + // TODO: Finish input output types. + override def apply(input: Row): Any = { + val evaluatedChildren = children.map(_.apply(input)) + // Wrap the function arguments in the expected types. + val args = evaluatedChildren.zip(wrappers).map { + case (arg, wrapper) => wrapper(arg) + } + + // Invoke the udf and unwrap the result. + unwrap(method.invoke(function, args: _*)) + } +} + +case class HiveGenericUdf( + name: String, + children: Seq[Expression]) extends HiveUdf with HiveInspectors { + import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ + type UDFType = GenericUDF + + @transient + protected lazy val argumentInspectors = children.map(_.dataType).map(toInspector) + + @transient + protected lazy val returnInspector = function.initialize(argumentInspectors.toArray) + + val dataType: DataType = inspectorToDataType(returnInspector) + + override def apply(input: Row): Any = { + returnInspector // Make sure initialized. 
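+ // Hive's GenericUDF contract passes arguments as DeferredObjects: each one below defers
+ // evaluation of its child expression until the UDF calls get(), at which point the
+ // Catalyst value is wrapped into its Hive representation.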
+ val args = children.map { v => + new DeferredObject { + override def prepare(i: Int) = {} + override def get(): AnyRef = wrap(v.apply(input)) + } + }.toArray + unwrap(function.evaluate(args)) + } +} + +trait HiveInspectors { + + def unwrapData(data: Any, oi: ObjectInspector): Any = oi match { + case pi: PrimitiveObjectInspector => pi.getPrimitiveJavaObject(data) + case li: ListObjectInspector => + Option(li.getList(data)) + .map(_.map(unwrapData(_, li.getListElementObjectInspector)).toSeq) + .orNull + case mi: MapObjectInspector => + Option(mi.getMap(data)).map( + _.map { + case (k,v) => + (unwrapData(k, mi.getMapKeyObjectInspector), + unwrapData(v, mi.getMapValueObjectInspector)) + }.toMap).orNull + case si: StructObjectInspector => + val allRefs = si.getAllStructFieldRefs + new GenericRow( + allRefs.map(r => + unwrapData(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray) + } + + /** Converts native catalyst types to the types expected by Hive */ + def wrap(a: Any): AnyRef = a match { + case s: String => new hadoopIo.Text(s) + case i: Int => i: java.lang.Integer + case b: Boolean => b: java.lang.Boolean + case d: Double => d: java.lang.Double + case l: Long => l: java.lang.Long + case l: Short => l: java.lang.Short + case l: Byte => l: java.lang.Byte + case s: Seq[_] => seqAsJavaList(s.map(wrap)) + case m: Map[_,_] => + mapAsJavaMap(m.map { case (k, v) => wrap(k) -> wrap(v) }) + case null => null + } + + def toInspector(dataType: DataType): ObjectInspector = dataType match { + case ArrayType(tpe) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe)) + case MapType(keyType, valueType) => + ObjectInspectorFactory.getStandardMapObjectInspector( + toInspector(keyType), toInspector(valueType)) + case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector + case IntegerType => PrimitiveObjectInspectorFactory.javaIntObjectInspector + case DoubleType => PrimitiveObjectInspectorFactory.javaDoubleObjectInspector + case BooleanType => PrimitiveObjectInspectorFactory.javaBooleanObjectInspector + case LongType => PrimitiveObjectInspectorFactory.javaLongObjectInspector + case FloatType => PrimitiveObjectInspectorFactory.javaFloatObjectInspector + case ShortType => PrimitiveObjectInspectorFactory.javaShortObjectInspector + case ByteType => PrimitiveObjectInspectorFactory.javaByteObjectInspector + case NullType => PrimitiveObjectInspectorFactory.javaVoidObjectInspector + case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector + } + + def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match { + case s: StructObjectInspector => + StructType(s.getAllStructFieldRefs.map(f => { + types.StructField( + f.getFieldName, inspectorToDataType(f.getFieldObjectInspector), nullable = true) + })) + case l: ListObjectInspector => ArrayType(inspectorToDataType(l.getListElementObjectInspector)) + case m: MapObjectInspector => + MapType( + inspectorToDataType(m.getMapKeyObjectInspector), + inspectorToDataType(m.getMapValueObjectInspector)) + case _: WritableStringObjectInspector => StringType + case _: JavaStringObjectInspector => StringType + case _: WritableIntObjectInspector => IntegerType + case _: JavaIntObjectInspector => IntegerType + case _: WritableDoubleObjectInspector => DoubleType + case _: JavaDoubleObjectInspector => DoubleType + case _: WritableBooleanObjectInspector => BooleanType + case _: JavaBooleanObjectInspector => BooleanType + case _: WritableLongObjectInspector => LongType + case _: 
JavaLongObjectInspector => LongType + case _: WritableShortObjectInspector => ShortType + case _: JavaShortObjectInspector => ShortType + case _: WritableByteObjectInspector => ByteType + case _: JavaByteObjectInspector => ByteType + } + + implicit class typeInfoConversions(dt: DataType) { + import org.apache.hadoop.hive.serde2.typeinfo._ + import TypeInfoFactory._ + + def toTypeInfo: TypeInfo = dt match { + case BinaryType => binaryTypeInfo + case BooleanType => booleanTypeInfo + case ByteType => byteTypeInfo + case DoubleType => doubleTypeInfo + case FloatType => floatTypeInfo + case IntegerType => intTypeInfo + case LongType => longTypeInfo + case ShortType => shortTypeInfo + case StringType => stringTypeInfo + case DecimalType => decimalTypeInfo + case NullType => voidTypeInfo + } + } +} + +case class HiveGenericUdaf( + name: String, + children: Seq[Expression]) extends AggregateExpression + with HiveInspectors + with HiveFunctionFactory { + + type UDFType = AbstractGenericUDAFResolver + + protected lazy val resolver: AbstractGenericUDAFResolver = createFunction(name) + + protected lazy val objectInspector = { + resolver.getEvaluator(children.map(_.dataType.toTypeInfo).toArray) + .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) + } + + protected lazy val inspectors = children.map(_.dataType).map(toInspector) + + def dataType: DataType = inspectorToDataType(objectInspector) + + def nullable: Boolean = true + + def references: Set[Attribute] = children.map(_.references).flatten.toSet + + override def toString = s"$nodeName#$name(${children.mkString(",")})" + + def newInstance = new HiveUdafFunction(name, children, this) +} + +/** + * Converts a Hive Generic User Defined Table Generating Function (UDTF) to a + * [[catalyst.expressions.Generator Generator]]. Note that the semantics of Generators do not allow + * Generators to maintain state in between input rows. Thus UDTFs that rely on partitioning + * dependent operations like calls to `close()` before producing output will not operate the same as + * in Hive. However, in practice this should not affect compatibility for most sane UDTFs + * (e.g. explode or GenericUDTFParseUrlTuple). + * + * Operators that require maintaining state in between input rows should instead be implemented as + * user defined aggregations, which have clean semantics even in a partitioned execution. + */ +case class HiveGenericUdtf( + name: String, + aliasNames: Seq[String], + children: Seq[Expression]) + extends Generator with HiveInspectors with HiveFunctionFactory { + + override def references = children.flatMap(_.references).toSet + + @transient + protected lazy val function: GenericUDTF = createFunction(name) + + protected lazy val inputInspectors = children.map(_.dataType).map(toInspector) + + protected lazy val outputInspectors = { + val structInspector = function.initialize(inputInspectors.toArray) + structInspector.getAllStructFieldRefs.map(_.getFieldObjectInspector) + } + + protected lazy val outputDataTypes = outputInspectors.map(inspectorToDataType) + + override protected def makeOutput() = { + // Use column names when given, otherwise c_1, c_2, ... c_n. 
+ if (aliasNames.size == outputDataTypes.size) { + aliasNames.zip(outputDataTypes).map { + case (attrName, attrDataType) => + AttributeReference(attrName, attrDataType, nullable = true)() + } + } else { + outputDataTypes.zipWithIndex.map { + case (attrDataType, i) => + AttributeReference(s"c_$i", attrDataType, nullable = true)() + } + } + } + + override def apply(input: Row): TraversableOnce[Row] = { + outputInspectors // Make sure initialized. + + val inputProjection = new Projection(children) + val collector = new UDTFCollector + function.setCollector(collector) + + val udtInput = inputProjection(input).map(wrap).toArray + function.process(udtInput) + collector.collectRows() + } + + protected class UDTFCollector extends Collector { + var collected = new ArrayBuffer[Row] + + override def collect(input: java.lang.Object) { + // We need to clone the input here because implementations of + // GenericUDTF reuse the same object. Luckily they are always an array, so + // it is easy to clone. + collected += new GenericRow(input.asInstanceOf[Array[_]].map(unwrap)) + } + + def collectRows() = { + val toCollect = collected + collected = new ArrayBuffer[Row] + toCollect + } + } + + override def toString() = s"$nodeName#$name(${children.mkString(",")})" +} + +case class HiveUdafFunction( + functionName: String, + exprs: Seq[Expression], + base: AggregateExpression) + extends AggregateFunction + with HiveInspectors + with HiveFunctionFactory { + + def this() = this(null, null, null) + + private val resolver = createFunction[AbstractGenericUDAFResolver](functionName) + + private val inspectors = exprs.map(_.dataType).map(toInspector).toArray + + private val function = resolver.getEvaluator(exprs.map(_.dataType.toTypeInfo).toArray) + + private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) + + // Cast required to avoid type inference selecting a deprecated Hive API. + private val buffer = + function.getNewAggregationBuffer.asInstanceOf[GenericUDAFEvaluator.AbstractAggregationBuffer] + + override def apply(input: Row): Any = unwrapData(function.evaluate(buffer), returnInspector) + + @transient + val inputProjection = new Projection(exprs) + + def update(input: Row): Unit = { + val inputs = inputProjection(input).asInstanceOf[Seq[AnyRef]].toArray + function.iterate(buffer, inputs) + } +} diff --git a/sql/hive/src/test/resources/log4j.properties b/sql/hive/src/test/resources/log4j.properties new file mode 100644 index 0000000000..5e17e3b596 --- /dev/null +++ b/sql/hive/src/test/resources/log4j.properties @@ -0,0 +1,47 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+# Set everything to be logged to the file target/unit-tests.log
+log4j.rootLogger=DEBUG, CA, FA
+
+# Console Appender
+log4j.appender.CA=org.apache.log4j.ConsoleAppender
+log4j.appender.CA.layout=org.apache.log4j.PatternLayout
+log4j.appender.CA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %p %c: %m%n
+log4j.appender.CA.Threshold = WARN
+
+# File Appender
+log4j.appender.FA=org.apache.log4j.FileAppender
+log4j.appender.FA.append=false
+log4j.appender.FA.file=target/unit-tests.log
+log4j.appender.FA.layout=org.apache.log4j.PatternLayout
+log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %p %c{1}: %m%n
+
+# Set the Threshold of the File Appender to INFO
+log4j.appender.FA.Threshold = INFO
+
+# Some packages are noisy for no good reason.
+log4j.additivity.org.apache.hadoop.hive.serde2.lazy.LazyStruct=false
+log4j.logger.org.apache.hadoop.hive.serde2.lazy.LazyStruct=OFF
+
+log4j.additivity.org.apache.hadoop.hive.metastore.RetryingHMSHandler=false
+log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=OFF
+
+log4j.additivity.hive.ql.metadata.Hive=false
+log4j.logger.hive.ql.metadata.Hive=OFF
+
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala
new file mode 100644
index 0000000000..4b45e69860
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+package execution
+
+import java.io.File
+
+/**
+ * A set of test cases based on the big-data-benchmark.
+ * https://amplab.cs.berkeley.edu/benchmark/ + */ +class BigDataBenchmarkSuite extends HiveComparisonTest { + import TestHive._ + + val testDataDirectory = new File("target/big-data-benchmark-testdata") + + val testTables = Seq( + TestTable( + "rankings", + s""" + |CREATE EXTERNAL TABLE rankings ( + | pageURL STRING, + | pageRank INT, + | avgDuration INT) + | ROW FORMAT DELIMITED FIELDS TERMINATED BY "," + | STORED AS TEXTFILE LOCATION "${new File(testDataDirectory, "rankings").getCanonicalPath}" + """.stripMargin.cmd), + TestTable( + "scratch", + s""" + |CREATE EXTERNAL TABLE scratch ( + | pageURL STRING, + | pageRank INT, + | avgDuration INT) + | ROW FORMAT DELIMITED FIELDS TERMINATED BY "," + | STORED AS TEXTFILE LOCATION "${new File(testDataDirectory, "scratch").getCanonicalPath}" + """.stripMargin.cmd), + TestTable( + "uservisits", + s""" + |CREATE EXTERNAL TABLE uservisits ( + | sourceIP STRING, + | destURL STRING, + | visitDate STRING, + | adRevenue DOUBLE, + | userAgent STRING, + | countryCode STRING, + | languageCode STRING, + | searchWord STRING, + | duration INT) + | ROW FORMAT DELIMITED FIELDS TERMINATED BY "," + | STORED AS TEXTFILE LOCATION "${new File(testDataDirectory, "uservisits").getCanonicalPath}" + """.stripMargin.cmd), + TestTable( + "documents", + s""" + |CREATE EXTERNAL TABLE documents (line STRING) + |STORED AS TEXTFILE + |LOCATION "${new File(testDataDirectory, "crawl").getCanonicalPath}" + """.stripMargin.cmd)) + + testTables.foreach(registerTestTable) + + if (!testDataDirectory.exists()) { + // TODO: Auto download the files on demand. + ignore("No data files found for BigDataBenchmark tests.") {} + } else { + createQueryTest("query1", + "SELECT pageURL, pageRank FROM rankings WHERE pageRank > 1") + + createQueryTest("query2", + "SELECT SUBSTR(sourceIP, 1, 10), SUM(adRevenue) FROM uservisits GROUP BY SUBSTR(sourceIP, 1, 10)") + + createQueryTest("query3", + """ + |SELECT sourceIP, + | sum(adRevenue) as totalRevenue, + | avg(pageRank) as pageRank + |FROM + | rankings R JOIN + | (SELECT sourceIP, destURL, adRevenue + | FROM uservisits UV + | WHERE UV.visitDate > "1980-01-01" + | AND UV.visitDate < "1980-04-01") + | NUV ON (R.pageURL = NUV.destURL) + |GROUP BY sourceIP + |ORDER BY totalRevenue DESC + |LIMIT 1 + """.stripMargin) + + createQueryTest("query4", + """ + |DROP TABLE IF EXISTS url_counts_partial; + |CREATE TABLE url_counts_partial AS + | SELECT TRANSFORM (line) + | USING 'python target/url_count.py' as (sourcePage, + | destPage, count) from documents; + |DROP TABLE IF EXISTS url_counts_total; + |CREATE TABLE url_counts_total AS + | SELECT SUM(count) AS totalCount, destpage + | FROM url_counts_partial GROUP BY destpage + |-- The following queries run, but generate different results in HIVE likely because the UDF is not deterministic + |-- given different input splits. + |-- SELECT CAST(SUM(count) AS INT) FROM url_counts_partial + |-- SELECT COUNT(*) FROM url_counts_partial + |-- SELECT * FROM url_counts_partial + |-- SELECT * FROM url_counts_total + """.stripMargin) + } +}
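The TestTable/registerTestTable pattern above is how further benchmark relations would be wired in. A minimal sketch under the same conventions; the "crawlstats" name, schema, and data directory are hypothetical and not part of the benchmark:

// Hypothetical extra table, registered with the same helpers imported
// from TestHive that the suite above uses.
val crawlStats = TestTable(
  "crawlstats",
  s"""
    |CREATE EXTERNAL TABLE crawlstats (
    |  url STRING,
    |  bytes INT)
    |ROW FORMAT DELIMITED FIELDS TERMINATED BY ","
    |STORED AS TEXTFILE LOCATION "${new File(testDataDirectory, "crawlstats").getCanonicalPath}"
  """.stripMargin.cmd)

registerTestTable(crawlStats)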
\ No newline at end of file diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala new file mode 100644 index 0000000000..a12ab23946 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark +package sql +package hive +package execution + + +import org.scalatest.{FunSuite, BeforeAndAfterAll} + +class ConcurrentHiveSuite extends FunSuite with BeforeAndAfterAll { + ignore("multiple instances not supported") { + test("Multiple Hive Instances") { + (1 to 10).map { i => + val ts = + new TestHiveContext(new SparkContext("local", s"TestSQLContext$i", new SparkConf())) + ts.executeSql("SHOW TABLES").toRdd.collect() + ts.executeSql("SELECT * FROM src").toRdd.collect() + ts.executeSql("SHOW TABLES").toRdd.collect() + } + } + } +}
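Each iteration above attaches a fresh TestHiveContext to its own SparkContext. Were the test ever enabled, each iteration would also need explicit teardown; a sketch, assuming only that stopping the SparkContext releases the resources the instance holds:

// Create, use, and tear down one isolated instance.
val sc = new SparkContext("local", "TestSQLContext-isolated", new SparkConf())
val ts = new TestHiveContext(sc)
try {
  ts.executeSql("SHOW TABLES").toRdd.collect()
} finally {
  sc.stop()
}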
\ No newline at end of file
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
new file mode 100644
index 0000000000..8a5b97b7a0
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -0,0 +1,379 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+package execution
+
+import java.io._
+import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen}
+
+import catalyst.plans.logical.{ExplainCommand, NativeCommand}
+import catalyst.plans._
+import catalyst.util._
+
+import org.apache.spark.sql.execution.Sort
+
+/**
+ * Allows the creation of tests that execute the same query against both Hive
+ * and Catalyst, comparing the results.
+ *
+ * The "golden" results from Hive are cached in and retrieved from both the classpath and
+ * [[answerCache]] to speed up testing.
+ *
+ * See the documentation of public vals in this class for information on how test execution can be
+ * configured using system properties.
+ */
+abstract class HiveComparisonTest extends FunSuite with BeforeAndAfterAll with GivenWhenThen with Logging {
+
+  /**
+   * When set, any cache files that result in test failures will be deleted. Used when the test
+   * harness or Hive has been updated, thus requiring new golden answers to be computed for some
+   * tests. Also prevents the classpath from being used when looking for golden answers, as these
+   * are usually stale.
+   */
+  val recomputeCache = System.getProperty("spark.hive.recomputeCache") != null
+
+  protected val shardRegEx = "(\\d+):(\\d+)".r
+  /**
+   * Allows multiple JVMs to be run in parallel, each responsible for a portion of all test cases.
+   * Format `shardId:numShards`. Shard ids should be zero indexed. E.g. -Dspark.hive.shard=0:4.
+   */
+  val shardInfo = Option(System.getProperty("spark.hive.shard")).map {
+    case shardRegEx(id, total) => (id.toInt, total.toInt)
+  }
+
+  protected val targetDir = new File("target")
+
+  /**
+   * When set, this comma-separated list defines directories that contain the names of test cases
+   * that should be skipped.
+   *
+   * For example, when `-Dspark.hive.skiptests=passed,hiveFailed` is specified, test cases listed
+   * in [[passedDirectory]] or [[hiveFailedDirectory]] will be skipped.
+   */
+  val skipDirectories =
+    Option(System.getProperty("spark.hive.skiptests"))
+      .toSeq
+      .flatMap(_.split(","))
+      .map(name => new File(targetDir, s"$suiteName.$name"))
+
+  val runOnlyDirectories =
+    Option(System.getProperty("spark.hive.runonlytests"))
+      .toSeq
+      .flatMap(_.split(","))
+      .map(name => new File(targetDir, s"$suiteName.$name"))
+
+  /** The local directory where cached golden answers are stored. */
+  protected val answerCache = new File("src/test/resources/golden")
+  if (!answerCache.exists) {
+    answerCache.mkdir()
+  }
+
+  /** The [[ClassLoader]] that contains test dependencies. Used to look for golden answers. */
+  protected val testClassLoader = this.getClass.getClassLoader
+
+  /** Directory containing a file for each test case that passes. */
+  val passedDirectory = new File(targetDir, s"$suiteName.passed")
+  if (!passedDirectory.exists()) {
+    passedDirectory.mkdir() // Not atomic!
+  }
+
+  /** Directory containing output of tests that fail to execute with Catalyst. */
+  val failedDirectory = new File(targetDir, s"$suiteName.failed")
+  if (!failedDirectory.exists()) {
+    failedDirectory.mkdir() // Not atomic!
+  }
+
+  /** Directory containing output of tests where Catalyst produces the wrong answer. */
+  val wrongDirectory = new File(targetDir, s"$suiteName.wrong")
+  if (!wrongDirectory.exists()) {
+    wrongDirectory.mkdir() // Not atomic!
+  }
+
+  /** Directory containing output of tests where we fail to generate golden output with Hive. */
+  val hiveFailedDirectory = new File(targetDir, s"$suiteName.hiveFailed")
+  if (!hiveFailedDirectory.exists()) {
+    hiveFailedDirectory.mkdir() // Not atomic!
+  }
+
+  /** All directories that contain per-query output files. */
+  val outputDirectories = Seq(
+    passedDirectory,
+    failedDirectory,
+    wrongDirectory,
+    hiveFailedDirectory)
+
+  protected val cacheDigest = java.security.MessageDigest.getInstance("MD5")
+  protected def getMd5(str: String): String = {
+    val digest = java.security.MessageDigest.getInstance("MD5")
+    digest.update(str.getBytes)
+    new java.math.BigInteger(1, digest.digest).toString(16)
+  }
+
+  protected def prepareAnswer(
+    hiveQuery: TestHive.type#SqlQueryExecution,
+    answer: Seq[String]): Seq[String] = {
+    val orderedAnswer = hiveQuery.logical match {
+      // Clean out non-deterministic schema info (timestamps etc.) and blank lines.
+      case _: NativeCommand => answer.filterNot(nonDeterministicLine).filterNot(_ == "")
+      case _: ExplainCommand => answer
+      case _ =>
+        // TODO: Really we only care about the final total ordering here...
+        val isOrdered = hiveQuery.executedPlan.collect {
+          case s @ Sort(_, global, _) if global => s
+        }.nonEmpty
+        // If the query results aren't sorted, then sort them to ensure deterministic answers.
+        if (!isOrdered) answer.sorted else answer
+    }
+    orderedAnswer.map(cleanPaths)
+  }
+
+  // TODO: Instead of filtering we should clean to avoid accidentally ignoring actual results.
+  lazy val nonDeterministicLineIndicators = Seq(
+    "CreateTime",
+    "transient_lastDdlTime",
+    "grantTime",
+    "lastUpdateTime",
+    "last_modified_time",
+    "Owner:",
+    // The following are Hive-specific schema parameters which we do not need to match exactly.
+    "numFiles",
+    "numRows",
+    "rawDataSize",
+    "totalSize",
+    "totalNumberFiles",
+    "maxFileSize",
+    "minFileSize"
+  )
+  protected def nonDeterministicLine(line: String) =
+    nonDeterministicLineIndicators.map(line contains _).reduceLeft(_ || _)
+
+  /**
+   * Removes non-deterministic paths from `str` so cached answers will compare correctly.
+ */ + protected def cleanPaths(str: String): String = { + str.replaceAll("file:\\/.*\\/", "<PATH>") + } + + val installHooksCommand = "(?i)SET.*hooks".r + def createQueryTest(testCaseName: String, sql: String) { + // If test sharding is enable, skip tests that are not in the correct shard. + shardInfo.foreach { + case (shardId, numShards) if testCaseName.hashCode % numShards != shardId => return + case (shardId, _) => logger.debug(s"Shard $shardId includes test '$testCaseName'") + } + + // Skip tests found in directories specified by user. + skipDirectories + .map(new File(_, testCaseName)) + .filter(_.exists) + .foreach(_ => return) + + // If runonlytests is set, skip this test unless we find a file in one of the specified + // directories. + val runIndicators = + runOnlyDirectories + .map(new File(_, testCaseName)) + .filter(_.exists) + if (runOnlyDirectories.nonEmpty && runIndicators.isEmpty) { + logger.debug( + s"Skipping test '$testCaseName' not found in ${runOnlyDirectories.map(_.getCanonicalPath)}") + return + } + + test(testCaseName) { + logger.debug(s"=== HIVE TEST: $testCaseName ===") + + // Clear old output for this testcase. + outputDirectories.map(new File(_, testCaseName)).filter(_.exists()).foreach(_.delete()) + + val allQueries = sql.split("(?<=[^\\\\]);").map(_.trim).filterNot(q => q == "").toSeq + + // TODO: DOCUMENT UNSUPPORTED + val queryList = + allQueries + // In hive, setting the hive.outerjoin.supports.filters flag to "false" essentially tells + // the system to return the wrong answer. Since we have no intention of mirroring their + // previously broken behavior we simply filter out changes to this setting. + .filterNot(_ contains "hive.outerjoin.supports.filters") + + if (allQueries != queryList) + logger.warn(s"Simplifications made on unsupported operations for test $testCaseName") + + lazy val consoleTestCase = { + val quotes = "\"\"\"" + queryList.zipWithIndex.map { + case (query, i) => + s""" + |val q$i = $quotes$query$quotes.q + |q$i.stringResult() + """.stripMargin + }.mkString("\n== Console version of this test ==\n", "\n", "\n") + } + + try { + // MINOR HACK: You must run a query before calling reset the first time. + TestHive.sql("SHOW TABLES") + TestHive.reset() + + val hiveCacheFiles = queryList.zipWithIndex.map { + case (queryString, i) => + val cachedAnswerName = s"$testCaseName-$i-${getMd5(queryString)}" + new File(answerCache, cachedAnswerName) + } + + val hiveCachedResults = hiveCacheFiles.flatMap { cachedAnswerFile => + logger.debug(s"Looking for cached answer file $cachedAnswerFile.") + if (cachedAnswerFile.exists) { + Some(fileToString(cachedAnswerFile)) + } else { + logger.debug(s"File $cachedAnswerFile not found") + None + } + }.map { + case "" => Nil + case "\n" => Seq("") + case other => other.split("\n").toSeq + } + + val hiveResults: Seq[Seq[String]] = + if (hiveCachedResults.size == queryList.size) { + logger.info(s"Using answer cache for test: $testCaseName") + hiveCachedResults + } else { + + val hiveQueries = queryList.map(new TestHive.SqlQueryExecution(_)) + // Make sure we can at least parse everything before attempting hive execution. + hiveQueries.foreach(_.logical) + val computedResults = (queryList.zipWithIndex, hiveQueries, hiveCacheFiles).zipped.map { + case ((queryString, i), hiveQuery, cachedAnswerFile)=> + try { + // Hooks often break the harness and don't really affect our test anyway, don't + // even try running them. 
+ if (installHooksCommand.findAllMatchIn(queryString).nonEmpty) + sys.error("hive exec hooks not supported for tests.") + + logger.warn(s"Running query ${i+1}/${queryList.size} with hive.") + // Analyze the query with catalyst to ensure test tables are loaded. + val answer = hiveQuery.analyzed match { + case _: ExplainCommand => Nil // No need to execute EXPLAIN queries as we don't check the output. + case _ => TestHive.runSqlHive(queryString) + } + + // We need to add a new line to non-empty answers so we can differentiate Seq() + // from Seq(""). + stringToFile( + cachedAnswerFile, answer.mkString("\n") + (if (answer.nonEmpty) "\n" else "")) + answer + } catch { + case e: Exception => + val errorMessage = + s""" + |Failed to generate golden answer for query: + |Error: ${e.getMessage} + |${stackTraceToString(e)} + |$queryString + |$consoleTestCase + """.stripMargin + stringToFile( + new File(hiveFailedDirectory, testCaseName), + errorMessage + consoleTestCase) + fail(errorMessage) + } + }.toSeq + TestHive.reset() + + computedResults + } + + // Run w/ catalyst + val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) => + val query = new TestHive.SqlQueryExecution(queryString) + try { (query, prepareAnswer(query, query.stringResult())) } catch { + case e: Exception => + val errorMessage = + s""" + |Failed to execute query using catalyst: + |Error: ${e.getMessage} + |${stackTraceToString(e)} + |$query + |== HIVE - ${hive.size} row(s) == + |${hive.mkString("\n")} + | + |$consoleTestCase + """.stripMargin + stringToFile(new File(failedDirectory, testCaseName), errorMessage + consoleTestCase) + fail(errorMessage) + } + }.toSeq + + (queryList, hiveResults, catalystResults).zipped.foreach { + case (query, hive, (hiveQuery, catalyst)) => + // Check that the results match unless its an EXPLAIN query. + val preparedHive = prepareAnswer(hiveQuery,hive) + + if ((!hiveQuery.logical.isInstanceOf[ExplainCommand]) && preparedHive != catalyst) { + + val hivePrintOut = s"== HIVE - ${hive.size} row(s) ==" +: preparedHive + val catalystPrintOut = s"== CATALYST - ${catalyst.size} row(s) ==" +: catalyst + + val resultComparison = sideBySide(hivePrintOut, catalystPrintOut).mkString("\n") + + if (recomputeCache) { + logger.warn(s"Clearing cache files for failed test $testCaseName") + hiveCacheFiles.foreach(_.delete()) + } + + val errorMessage = + s""" + |Results do not match for $testCaseName: + |$hiveQuery\n${hiveQuery.analyzed.output.map(_.name).mkString("\t")} + |$resultComparison + """.stripMargin + + stringToFile(new File(wrongDirectory, testCaseName), errorMessage + consoleTestCase) + fail(errorMessage) + } + } + + // Touch passed file. + new FileOutputStream(new File(passedDirectory, testCaseName)).close() + } catch { + case tf: org.scalatest.exceptions.TestFailedException => throw tf + case originalException: Exception => + if (System.getProperty("spark.hive.canarytest") != null) { + // When we encounter an error we check to see if the environment is still okay by running a simple query. + // If this fails then we halt testing since something must have gone seriously wrong. + try { + new TestHive.SqlQueryExecution("SELECT key FROM src").stringResult() + TestHive.runSqlHive("SELECT key FROM src") + } catch { + case e: Exception => + logger.error(s"FATAL ERROR: Canary query threw $e This implies that the testing environment has likely been corrupted.") + // The testing setup traps exits so wait here for a long time so the developer can see when things started + // to go wrong. 
+ Thread.sleep(1000000) + } + } + + // If the canary query didn't fail then the environment is still okay, so just throw the original exception. + throw originalException + } + } + } +}
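The sharding check in createQueryTest above assigns a test to the shard where testCaseName.hashCode % numShards == shardId. A standalone sketch of that assignment; note that the JVM remainder keeps the sign of the dividend, so a test name with a negative hashCode matches no shard id in [0, numShards) and would be skipped by every JVM:

// Demonstrates how test names distribute across -Dspark.hive.shard=<id>:<n> shards.
val numShards = 4
Seq("auto_join0", "groupby1", "udf_abs", "union10").foreach { name =>
  val shard = name.hashCode % numShards // may be negative for some names
  println(s"$name -> shard $shard")
}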
\ No newline at end of file diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala new file mode 100644 index 0000000000..d010023f78 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -0,0 +1,708 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package hive +package execution + + +import java.io._ + +import util._ + +/** + * Runs the test cases that are included in the hive distribution. + */ +class HiveCompatibilitySuite extends HiveQueryFileTest { + // TODO: bundle in jar files... get from classpath + lazy val hiveQueryDir = TestHive.getHiveFile("ql/src/test/queries/clientpositive") + def testCases = hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) + + /** A list of tests deemed out of scope currently and thus completely disregarded. */ + override def blackList = Seq( + // These tests use hooks that are not on the classpath and thus break all subsequent execution. + "hook_order", + "hook_context", + "mapjoin_hook", + "multi_sahooks", + "overridden_confs", + "query_properties", + "sample10", + "updateAccessTime", + "index_compact_binary_search", + "bucket_num_reducers", + "column_access_stats", + "concatenate_inherit_table_location", + + // Setting a default property does not seem to get reset and thus changes the answer for many + // subsequent tests. + "create_default_prop", + + // User/machine specific test answers, breaks the caching mechanism. + "authorization_3", + "authorization_5", + "keyword_1", + "misc_json", + "create_like_tbl_props", + "load_overwrite", + "alter_table_serde2", + "alter_table_not_sorted", + "alter_skewed_table", + "alter_partition_clusterby_sortby", + "alter_merge", + "alter_concatenate_indexed_table", + "protectmode2", + "describe_table", + "describe_comment_nonascii", + "udf5", + "udf_java_method", + + // Weird DDL differences result in failures on jenkins. + "create_like2", + "create_view_translate", + "partitions_json", + + // Timezone specific test answers. + "udf_unix_timestamp", + "udf_to_unix_timestamp", + + // Cant run without local map/reduce. + "index_auto_update", + "index_auto_self_join", + "index_stale.*", + "type_cast_1", + "index_compression", + "index_bitmap_compression", + "index_auto_multiple", + "index_auto_mult_tables_compact", + "index_auto_mult_tables", + "index_auto_file_format", + "index_auth", + "index_auto_empty", + "index_auto_partitioned", + "index_auto_unused", + "index_bitmap_auto_partitioned", + "ql_rewrite_gbtoidx", + "stats1.*", + "stats20", + "alter_merge_stats", + + // Hive seems to think 1.0 > NaN = true && 1.0 < NaN = false... 
which is wrong. + // http://stackoverflow.com/a/1573715 + "ops_comparison", + + // Tests that seems to never complete on hive... + "skewjoin", + "database", + + // These tests fail and and exit the JVM. + "auto_join18_multi_distinct", + "join18_multi_distinct", + "input44", + "input42", + "input_dfs", + "metadata_export_drop", + "repair", + + // Uses a serde that isn't on the classpath... breaks other tests. + "bucketizedhiveinputformat", + + // Avro tests seem to change the output format permanently thus breaking the answer cache, until + // we figure out why this is the case let just ignore all of avro related tests. + ".*avro.*", + + // Unique joins are weird and will require a lot of hacks (see comments in hive parser). + "uniquejoin", + + // Hive seems to get the wrong answer on some outer joins. MySQL agrees with catalyst. + "auto_join29", + + // No support for multi-alias i.e. udf as (e1, e2, e3). + "allcolref_in_udf", + + // No support for TestSerDe (not published afaik) + "alter1", + "input16", + + // No support for unpublished test udfs. + "autogen_colalias", + + // Hive does not support buckets. + ".*bucket.*", + + // No window support yet + ".*window.*", + + // Fails in hive with authorization errors. + "alter_rename_partition_authorization", + "authorization.*", + + // Hadoop version specific tests + "archive_corrupt", + + // No support for case sensitivity is resolution using hive properties atm. + "case_sensitivity" + ) + + /** + * The set of tests that are believed to be working in catalyst. Tests not on whiteList or + * blacklist are implicitly marked as ignored. + */ + override def whiteList = Seq( + "add_part_exist", + "add_partition_no_whitelist", + "add_partition_with_whitelist", + "alias_casted_column", + "alter2", + "alter4", + "alter5", + "alter_index", + "alter_merge_2", + "alter_partition_format_loc", + "alter_partition_protect_mode", + "alter_partition_with_whitelist", + "alter_table_serde", + "alter_varchar2", + "alter_view_as_select", + "ambiguous_col", + "auto_join0", + "auto_join1", + "auto_join10", + "auto_join11", + "auto_join12", + "auto_join13", + "auto_join14", + "auto_join14_hadoop20", + "auto_join15", + "auto_join17", + "auto_join18", + "auto_join19", + "auto_join2", + "auto_join20", + "auto_join21", + "auto_join22", + "auto_join23", + "auto_join24", + "auto_join25", + "auto_join26", + "auto_join27", + "auto_join28", + "auto_join3", + "auto_join30", + "auto_join31", + "auto_join32", + "auto_join4", + "auto_join5", + "auto_join6", + "auto_join7", + "auto_join8", + "auto_join9", + "auto_join_filters", + "auto_join_nulls", + "auto_join_reordering_values", + "auto_sortmerge_join_1", + "auto_sortmerge_join_10", + "auto_sortmerge_join_11", + "auto_sortmerge_join_12", + "auto_sortmerge_join_15", + "auto_sortmerge_join_2", + "auto_sortmerge_join_3", + "auto_sortmerge_join_4", + "auto_sortmerge_join_5", + "auto_sortmerge_join_6", + "auto_sortmerge_join_7", + "auto_sortmerge_join_8", + "auto_sortmerge_join_9", + "binary_constant", + "binarysortable_1", + "combine1", + "compute_stats_binary", + "compute_stats_boolean", + "compute_stats_double", + "compute_stats_table", + "compute_stats_long", + "compute_stats_string", + "convert_enum_to_string", + "correlationoptimizer11", + "correlationoptimizer15", + "correlationoptimizer2", + "correlationoptimizer3", + "correlationoptimizer4", + "correlationoptimizer6", + "correlationoptimizer7", + "correlationoptimizer8", + "count", + "create_like_view", + "create_nested_type", + "create_skewed_table1", + 
"create_struct_table", + "ct_case_insensitive", + "database_location", + "database_properties", + "decimal_join", + "default_partition_name", + "delimiter", + "desc_non_existent_tbl", + "describe_comment_indent", + "describe_database_json", + "describe_pretty", + "describe_syntax", + "describe_table_json", + "diff_part_input_formats", + "disable_file_format_check", + "drop_function", + "drop_index", + "drop_partitions_filter", + "drop_partitions_filter2", + "drop_partitions_filter3", + "drop_partitions_ignore_protection", + "drop_table", + "drop_table2", + "drop_view", + "escape_clusterby1", + "escape_distributeby1", + "escape_orderby1", + "escape_sortby1", + "fetch_aggregation", + "filter_join_breaktask", + "filter_join_breaktask2", + "groupby1", + "groupby11", + "groupby1_map", + "groupby1_map_nomap", + "groupby1_map_skew", + "groupby1_noskew", + "groupby4", + "groupby4_map", + "groupby4_map_skew", + "groupby4_noskew", + "groupby5", + "groupby5_map", + "groupby5_map_skew", + "groupby5_noskew", + "groupby6", + "groupby6_map", + "groupby6_map_skew", + "groupby6_noskew", + "groupby7", + "groupby7_map", + "groupby7_map_multi_single_reducer", + "groupby7_map_skew", + "groupby7_noskew", + "groupby8_map", + "groupby8_map_skew", + "groupby8_noskew", + "groupby_distinct_samekey", + "groupby_multi_single_reducer2", + "groupby_mutli_insert_common_distinct", + "groupby_neg_float", + "groupby_sort_10", + "groupby_sort_6", + "groupby_sort_8", + "groupby_sort_test_1", + "implicit_cast1", + "innerjoin", + "inoutdriver", + "input", + "input0", + "input11", + "input11_limit", + "input12", + "input12_hadoop20", + "input19", + "input1_limit", + "input22", + "input23", + "input24", + "input25", + "input26", + "input28", + "input2_limit", + "input40", + "input41", + "input4_cb_delim", + "input6", + "input7", + "input8", + "input9", + "input_limit", + "input_part0", + "input_part1", + "input_part10", + "input_part10_win", + "input_part2", + "input_part3", + "input_part4", + "input_part5", + "input_part6", + "input_part7", + "input_part8", + "input_part9", + "inputddl4", + "inputddl7", + "inputddl8", + "insert_compressed", + "join0", + "join1", + "join10", + "join11", + "join12", + "join13", + "join14", + "join14_hadoop20", + "join15", + "join16", + "join17", + "join18", + "join19", + "join2", + "join20", + "join21", + "join22", + "join23", + "join24", + "join25", + "join26", + "join27", + "join28", + "join29", + "join3", + "join30", + "join31", + "join32", + "join33", + "join34", + "join35", + "join36", + "join37", + "join38", + "join39", + "join4", + "join40", + "join41", + "join5", + "join6", + "join7", + "join8", + "join9", + "join_1to1", + "join_array", + "join_casesensitive", + "join_empty", + "join_filters", + "join_hive_626", + "join_nulls", + "join_reorder2", + "join_reorder3", + "join_reorder4", + "join_star", + "join_view", + "lateral_view_cp", + "lateral_view_ppd", + "lineage1", + "literal_double", + "literal_ints", + "literal_string", + "load_dyn_part7", + "load_file_with_space_in_the_name", + "louter_join_ppr", + "mapjoin_distinct", + "mapjoin_mapjoin", + "mapjoin_subquery", + "mapjoin_subquery2", + "mapjoin_test_outer", + "mapreduce3", + "mapreduce7", + "merge1", + "merge2", + "mergejoins", + "mergejoins_mixed", + "multiMapJoin1", + "multiMapJoin2", + "multi_join_union", + "multigroupby_singlemr", + "noalias_subq1", + "nomore_ambiguous_table_col", + "nonblock_op_deduplicate", + "notable_alias1", + "notable_alias2", + "nullgroup", + "nullgroup2", + "nullgroup3", + "nullgroup4", + 
"nullgroup4_multi_distinct", + "nullgroup5", + "nullinput", + "nullinput2", + "nullscript", + "optional_outer", + "order", + "order2", + "outer_join_ppr", + "part_inherit_tbl_props", + "part_inherit_tbl_props_empty", + "part_inherit_tbl_props_with_star", + "partition_schema1", + "partition_varchar1", + "plan_json", + "ppd1", + "ppd_constant_where", + "ppd_gby", + "ppd_gby2", + "ppd_gby_join", + "ppd_join", + "ppd_join2", + "ppd_join3", + "ppd_join_filter", + "ppd_outer_join1", + "ppd_outer_join2", + "ppd_outer_join3", + "ppd_outer_join4", + "ppd_outer_join5", + "ppd_random", + "ppd_repeated_alias", + "ppd_udf_col", + "ppd_union", + "ppr_allchildsarenull", + "ppr_pushdown", + "ppr_pushdown2", + "ppr_pushdown3", + "progress_1", + "protectmode", + "push_or", + "query_with_semi", + "quote1", + "quote2", + "reduce_deduplicate_exclude_join", + "rename_column", + "router_join_ppr", + "select_as_omitted", + "select_unquote_and", + "select_unquote_not", + "select_unquote_or", + "serde_reported_schema", + "set_variable_sub", + "show_describe_func_quotes", + "show_functions", + "show_partitions", + "skewjoinopt13", + "skewjoinopt18", + "skewjoinopt9", + "smb_mapjoin_1", + "smb_mapjoin_10", + "smb_mapjoin_13", + "smb_mapjoin_14", + "smb_mapjoin_15", + "smb_mapjoin_16", + "smb_mapjoin_17", + "smb_mapjoin_2", + "smb_mapjoin_21", + "smb_mapjoin_25", + "smb_mapjoin_3", + "smb_mapjoin_4", + "smb_mapjoin_5", + "smb_mapjoin_8", + "sort", + "sort_merge_join_desc_1", + "sort_merge_join_desc_2", + "sort_merge_join_desc_3", + "sort_merge_join_desc_4", + "sort_merge_join_desc_5", + "sort_merge_join_desc_6", + "sort_merge_join_desc_7", + "stats0", + "stats_empty_partition", + "subq2", + "tablename_with_select", + "touch", + "type_widening", + "udaf_collect_set", + "udaf_corr", + "udaf_covar_pop", + "udaf_covar_samp", + "udf2", + "udf6", + "udf9", + "udf_10_trims", + "udf_E", + "udf_PI", + "udf_abs", + "udf_acos", + "udf_add", + "udf_array", + "udf_array_contains", + "udf_ascii", + "udf_asin", + "udf_atan", + "udf_avg", + "udf_bigint", + "udf_bin", + "udf_bitmap_and", + "udf_bitmap_empty", + "udf_bitmap_or", + "udf_bitwise_and", + "udf_bitwise_not", + "udf_bitwise_or", + "udf_bitwise_xor", + "udf_boolean", + "udf_case", + "udf_ceil", + "udf_ceiling", + "udf_concat", + "udf_concat_insert2", + "udf_concat_ws", + "udf_conv", + "udf_cos", + "udf_count", + "udf_date_add", + "udf_date_sub", + "udf_datediff", + "udf_day", + "udf_dayofmonth", + "udf_degrees", + "udf_div", + "udf_double", + "udf_exp", + "udf_field", + "udf_find_in_set", + "udf_float", + "udf_floor", + "udf_format_number", + "udf_from_unixtime", + "udf_greaterthan", + "udf_greaterthanorequal", + "udf_hex", + "udf_if", + "udf_index", + "udf_int", + "udf_isnotnull", + "udf_isnull", + "udf_java_method", + "udf_lcase", + "udf_length", + "udf_lessthan", + "udf_lessthanorequal", + "udf_like", + "udf_ln", + "udf_log", + "udf_log10", + "udf_log2", + "udf_lower", + "udf_lpad", + "udf_ltrim", + "udf_map", + "udf_minute", + "udf_modulo", + "udf_month", + "udf_negative", + "udf_not", + "udf_notequal", + "udf_notop", + "udf_nvl", + "udf_or", + "udf_parse_url", + "udf_positive", + "udf_pow", + "udf_power", + "udf_radians", + "udf_rand", + "udf_regexp", + "udf_regexp_extract", + "udf_regexp_replace", + "udf_repeat", + "udf_rlike", + "udf_round", + "udf_round_3", + "udf_rpad", + "udf_rtrim", + "udf_second", + "udf_sign", + "udf_sin", + "udf_smallint", + "udf_space", + "udf_sqrt", + "udf_std", + "udf_stddev", + "udf_stddev_pop", + "udf_stddev_samp", + "udf_string", + 
"udf_substring", + "udf_subtract", + "udf_sum", + "udf_tan", + "udf_tinyint", + "udf_to_byte", + "udf_to_date", + "udf_to_double", + "udf_to_float", + "udf_to_long", + "udf_to_short", + "udf_translate", + "udf_trim", + "udf_ucase", + "udf_upper", + "udf_var_pop", + "udf_var_samp", + "udf_variance", + "udf_weekofyear", + "udf_when", + "udf_xpath", + "udf_xpath_boolean", + "udf_xpath_double", + "udf_xpath_float", + "udf_xpath_int", + "udf_xpath_long", + "udf_xpath_short", + "udf_xpath_string", + "unicode_notation", + "union10", + "union11", + "union13", + "union14", + "union15", + "union16", + "union17", + "union18", + "union19", + "union2", + "union20", + "union22", + "union23", + "union24", + "union26", + "union27", + "union28", + "union29", + "union30", + "union31", + "union34", + "union4", + "union5", + "union6", + "union7", + "union8", + "union9", + "union_lateralview", + "union_ppr", + "union_remove_3", + "union_remove_6", + "union_script", + "varchar_2", + "varchar_join1", + "varchar_union1" + ) +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala new file mode 100644 index 0000000000..f0a4ec3c02 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package hive +package execution + +import java.io._ + +import catalyst.util._ + +/** + * A framework for running the query tests that are listed as a set of text files. + * + * TestSuites that derive from this class must provide a map of testCaseName -> testCaseFiles that should be included. + * Additionally, there is support for whitelisting and blacklisting tests as development progresses. + */ +abstract class HiveQueryFileTest extends HiveComparisonTest { + /** A list of tests deemed out of scope and thus completely disregarded */ + def blackList: Seq[String] = Nil + + /** + * The set of tests that are believed to be working in catalyst. Tests not in whiteList + * blacklist are implicitly marked as ignored. + */ + def whiteList: Seq[String] = ".*" :: Nil + + def testCases: Seq[(String, File)] + + val runAll = + !(System.getProperty("spark.hive.alltests") == null) || + runOnlyDirectories.nonEmpty || + skipDirectories.nonEmpty + + val whiteListProperty = "spark.hive.whitelist" + // Allow the whiteList to be overridden by a system property + val realWhiteList = + Option(System.getProperty(whiteListProperty)).map(_.split(",").toSeq).getOrElse(whiteList) + + // Go through all the test cases and add them to scala test. 
+  testCases.sorted.foreach {
+    case (testCaseName, testCaseFile) =>
+      if (blackList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_ || _)) {
+        logger.debug(s"Blacklisted test skipped $testCaseName")
+      } else if (realWhiteList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_ || _) || runAll) {
+        // Build a test case and submit it to the ScalaTest framework...
+        val queriesString = fileToString(testCaseFile)
+        createQueryTest(testCaseName, queriesString)
+      } else {
+        // Only output warnings for the built-in whitelist, as these clutter the output when the
+        // user is trying to execute a single test from the command line.
+        if (System.getProperty(whiteListProperty) == null && !runAll)
+          ignore(testCaseName) {}
+      }
+  }
+}
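Both lists are applied as full-string regular expressions, which is why entries such as "index_stale.*" in a blackList exclude whole families of tests. A minimal sketch of the matching; note also that reduceLeft throws on an empty sequence, so the default blackList of Nil only works for suites that override it:

// Each entry is compiled with .r and must match the entire test case name.
val blackList = Seq("index_stale.*", "stats1.*")
def isBlacklisted(name: String): Boolean =
  blackList.map(_.r.pattern.matcher(name).matches()).reduceLeft(_ || _)

assert(isBlacklisted("index_stale_restored")) // covered by index_stale.*
assert(!isBlacklisted("index_auth"))          // full match required, not substring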
\ No newline at end of file diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala new file mode 100644 index 0000000000..28a5d260b3 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package hive +package execution + + +/** + * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. + */ +class HiveQuerySuite extends HiveComparisonTest { + import TestHive._ + + createQueryTest("Simple Average", + "SELECT AVG(key) FROM src") + + createQueryTest("Simple Average + 1", + "SELECT AVG(key) + 1.0 FROM src") + + createQueryTest("Simple Average + 1 with group", + "SELECT AVG(key) + 1.0, value FROM src group by value") + + createQueryTest("string literal", + "SELECT 'test' FROM src") + + createQueryTest("Escape sequences", + """SELECT key, '\\\t\\' FROM src WHERE key = 86""") + + createQueryTest("IgnoreExplain", + """EXPLAIN SELECT key FROM src""") + + createQueryTest("trivial join where clause", + "SELECT * FROM src a JOIN src b WHERE a.key = b.key") + + createQueryTest("trivial join ON clause", + "SELECT * FROM src a JOIN src b ON a.key = b.key") + + createQueryTest("small.cartesian", + "SELECT a.key, b.key FROM (SELECT key FROM src WHERE key < 1) a JOIN (SELECT key FROM src WHERE key = 2) b") + + createQueryTest("length.udf", + "SELECT length(\"test\") FROM src LIMIT 1") + + ignore("partitioned table scan") { + createQueryTest("partitioned table scan", + "SELECT ds, hr, key, value FROM srcpart") + } + + createQueryTest("hash", + "SELECT hash('test') FROM src LIMIT 1") + + createQueryTest("create table as", + """ + |CREATE TABLE createdtable AS SELECT * FROM src; + |SELECT * FROM createdtable + """.stripMargin) + + createQueryTest("create table as with db name", + """ + |CREATE DATABASE IF NOT EXISTS testdb; + |CREATE TABLE testdb.createdtable AS SELECT * FROM default.src; + |SELECT * FROM testdb.createdtable; + |DROP DATABASE IF EXISTS testdb CASCADE + """.stripMargin) + + createQueryTest("insert table with db name", + """ + |CREATE DATABASE IF NOT EXISTS testdb; + |CREATE TABLE testdb.createdtable like default.src; + |INSERT INTO TABLE testdb.createdtable SELECT * FROM default.src; + |SELECT * FROM testdb.createdtable; + |DROP DATABASE IF EXISTS testdb CASCADE + """.stripMargin) + + createQueryTest("insert into and insert overwrite", + """ + |CREATE TABLE createdtable like src; + |INSERT INTO TABLE createdtable SELECT * FROM src; + |INSERT INTO TABLE createdtable SELECT * FROM src1; + |SELECT * FROM createdtable; + |INSERT 
OVERWRITE TABLE createdtable SELECT * FROM src WHERE key = 86; + |SELECT * FROM createdtable; + """.stripMargin) + + createQueryTest("transform", + "SELECT TRANSFORM (key) USING 'cat' AS (tKey) FROM src") + + createQueryTest("LIKE", + "SELECT * FROM src WHERE value LIKE '%1%'") + + createQueryTest("DISTINCT", + "SELECT DISTINCT key, value FROM src") + + ignore("empty aggregate input") { + createQueryTest("empty aggregate input", + "SELECT SUM(key) FROM (SELECT * FROM src LIMIT 0) a") + } + + createQueryTest("lateral view1", + "SELECT tbl.* FROM src LATERAL VIEW explode(array(1,2)) tbl as a") + + createQueryTest("lateral view2", + "SELECT * FROM src LATERAL VIEW explode(array(1,2)) tbl") + + + createQueryTest("lateral view3", + "FROM src SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX") + + createQueryTest("lateral view4", + """ + |create table src_lv1 (key string, value string); + |create table src_lv2 (key string, value string); + | + |FROM src + |insert overwrite table src_lv1 SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX + |insert overwrite table src_lv2 SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX + """.stripMargin) + + createQueryTest("lateral view5", + "FROM src SELECT explode(array(key+3, key+4))") + + createQueryTest("lateral view6", + "SELECT * FROM src LATERAL VIEW explode(map(key+3,key+4)) D as k, v") + + test("sampling") { + sql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s") + } + +}
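Several of the tests above ("create table as", "insert into and insert overwrite", "lateral view4") are multi-statement scripts; they rely on HiveComparisonTest splitting on semicolons that are not preceded by a backslash. A small sketch of that split:

// Break on any ';' not preceded by a backslash, trim, and drop empty fragments.
val script =
  """CREATE TABLE createdtable AS SELECT * FROM src;
    |SELECT * FROM createdtable""".stripMargin
val statements = script.split("(?<=[^\\\\]);").map(_.trim).filterNot(_ == "")
assert(statements.size == 2)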
\ No newline at end of file diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala new file mode 100644 index 0000000000..0dd79faa15 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package hive +package execution + +/** + * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. + */ +class HiveResolutionSuite extends HiveComparisonTest { + import TestHive._ + + createQueryTest("table.attr", + "SELECT src.key FROM src ORDER BY key LIMIT 1") + + createQueryTest("database.table", + "SELECT key FROM default.src ORDER BY key LIMIT 1") + + createQueryTest("database.table table.attr", + "SELECT src.key FROM default.src ORDER BY key LIMIT 1") + + createQueryTest("alias.attr", + "SELECT a.key FROM src a ORDER BY key LIMIT 1") + + createQueryTest("subquery-alias.attr", + "SELECT a.key FROM (SELECT * FROM src ORDER BY key LIMIT 1) a") + + createQueryTest("quoted alias.attr", + "SELECT `a`.`key` FROM src a ORDER BY key LIMIT 1") + + createQueryTest("attr", + "SELECT key FROM src a ORDER BY key LIMIT 1") + + createQueryTest("alias.*", + "SELECT a.* FROM src a ORDER BY key LIMIT 1") + + /** + * Negative examples. Currently only left here for documentation purposes. + * TODO(marmbrus): Test that catalyst fails on these queries. + */ + + /* SemanticException [Error 10009]: Line 1:7 Invalid table alias 'src' + createQueryTest("table.*", + "SELECT src.* FROM src a ORDER BY key LIMIT 1") */ + + /* Invalid table alias or column reference 'src': (possible column names are: key, value) + createQueryTest("tableName.attr from aliased subquery", + "SELECT src.key FROM (SELECT * FROM src ORDER BY key LIMIT 1) a") */ + +}
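The negative cases kept in comments above could be asserted directly once the TODO is addressed. A sketch, assuming only that analysis throws an exception for the invalid reference:

// Mirrors Hive's error for a column resolved through a table name hidden by an alias.
test("tableName.attr from aliased subquery is rejected") {
  intercept[Exception] {
    new TestHive.SqlQueryExecution(
      "SELECT src.key FROM (SELECT * FROM src ORDER BY key LIMIT 1) a").analyzed
  }
}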
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
new file mode 100644
index 0000000000..c2264926f4
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+package execution
+
+/**
+ * A set of tests that validate support for Hive SerDes.
+ */
+class HiveSerDeSuite extends HiveComparisonTest {
+  createQueryTest(
+    "Read and write with LazySimpleSerDe (tab separated)",
+    "SELECT * FROM serdeins")
+
+  createQueryTest("Read with RegexSerDe", "SELECT * FROM sales")
+
+  createQueryTest("Read with AvroSerDe", "SELECT * FROM episodes")
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala
new file mode 100644
index 0000000000..bb33583e5f
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+/**
+ * A set of tests that validate type promotion rules.
+ */
+class HiveTypeCoercionSuite extends HiveComparisonTest {
+
+  val baseTypes = Seq("1", "1.0", "1L", "1S", "1Y", "'1'")
+
+  baseTypes.foreach { i =>
+    baseTypes.foreach { j =>
+      createQueryTest(s"$i + $j", s"SELECT $i + $j FROM src LIMIT 1")
+    }
+  }
+}
\ No newline at end of file
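// The nested foreach above expands to the full 6 x 6 = 36 pairwise additions,
// covering int ("1"), double ("1.0"), bigint ("1L"), smallint ("1S"),
// tinyint ("1Y"), and string ("'1'") literals. A hedged, self-contained way to
// enumerate exactly which tests get generated:
object CoercionMatrixSketch extends App {
  val baseTypes = Seq("1", "1.0", "1L", "1S", "1Y", "'1'")
  for (i <- baseTypes; j <- baseTypes)
    println(s"""createQueryTest("$i + $j", "SELECT $i + $j FROM src LIMIT 1")""")
}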
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
new file mode 100644
index 0000000000..8542f42aa9
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+package execution
+
+import scala.collection.JavaConversions._
+
+import org.apache.spark.sql.hive.TestHive
+
+/**
+ * A set of test cases that validate partition and column pruning.
+ */
+class PruningSuite extends HiveComparisonTest {
+  // Column pruning tests
+
+  createPruningTest("Column pruning: with partitioned table",
+    "SELECT key FROM srcpart WHERE ds = '2008-04-08' LIMIT 3",
+    Seq("key"),
+    Seq("key", "ds"),
+    Seq(
+      Seq("2008-04-08", "11"),
+      Seq("2008-04-08", "12")))
+
+  createPruningTest("Column pruning: with non-partitioned table",
+    "SELECT key FROM src WHERE key > 10 LIMIT 3",
+    Seq("key"),
+    Seq("key"),
+    Seq.empty)
+
+  createPruningTest("Column pruning: with multiple projects",
+    "SELECT c1 FROM (SELECT key AS c1 FROM src WHERE key > 10) t1 LIMIT 3",
+    Seq("c1"),
+    Seq("key"),
+    Seq.empty)
+
+  createPruningTest("Column pruning: project alias substitution",
+    "SELECT c1 AS c2 FROM (SELECT key AS c1 FROM src WHERE key > 10) t1 LIMIT 3",
+    Seq("c2"),
+    Seq("key"),
+    Seq.empty)
+
+  createPruningTest("Column pruning: filter alias in-lining",
+    "SELECT c1 FROM (SELECT key AS c1 FROM src WHERE key > 10) t1 WHERE c1 < 100 LIMIT 3",
+    Seq("c1"),
+    Seq("key"),
+    Seq.empty)
+
+  createPruningTest("Column pruning: without filters",
+    "SELECT c1 FROM (SELECT key AS c1 FROM src) t1 LIMIT 3",
+    Seq("c1"),
+    Seq("key"),
+    Seq.empty)
+
+  createPruningTest("Column pruning: simple top project without aliases",
+    "SELECT key FROM (SELECT key FROM src WHERE key > 10) t1 WHERE key < 100 LIMIT 3",
+    Seq("key"),
+    Seq("key"),
+    Seq.empty)
+
+  createPruningTest("Column pruning: non-trivial top project with aliases",
+    "SELECT c1 * 2 AS double FROM (SELECT key AS c1 FROM src WHERE key > 10) t1 LIMIT 3",
+    Seq("double"),
+    Seq("key"),
+    Seq.empty)
+
+  // Partition pruning tests
+
+  createPruningTest("Partition pruning: non-partitioned, non-trivial project",
+    "SELECT key * 2 AS double FROM src WHERE value IS NOT NULL",
+    Seq("double"),
+    Seq("key", "value"),
+    Seq.empty)
+
+  createPruningTest("Partition pruning: non-partitioned table",
+    "SELECT value FROM src WHERE key IS NOT NULL",
+    Seq("value"),
+    Seq("value", "key"),
+    Seq.empty)
+
+  createPruningTest("Partition pruning: with filter on string partition key",
+    "SELECT value, hr FROM srcpart1 WHERE ds = '2008-04-08'",
+    Seq("value", "hr"),
Seq("value", "hr", "ds"), + Seq( + Seq("2008-04-08", "11"), + Seq("2008-04-08", "12"))) + + createPruningTest("Partition pruning: with filter on int partition key", + "SELECT value, hr FROM srcpart1 WHERE hr < 12", + Seq("value", "hr"), + Seq("value", "hr"), + Seq( + Seq("2008-04-08", "11"), + Seq("2008-04-09", "11"))) + + createPruningTest("Partition pruning: left only 1 partition", + "SELECT value, hr FROM srcpart1 WHERE ds = '2008-04-08' AND hr < 12", + Seq("value", "hr"), + Seq("value", "hr", "ds"), + Seq( + Seq("2008-04-08", "11"))) + + createPruningTest("Partition pruning: all partitions pruned", + "SELECT value, hr FROM srcpart1 WHERE ds = '2014-01-27' AND hr = 11", + Seq("value", "hr"), + Seq("value", "hr", "ds"), + Seq.empty) + + createPruningTest("Partition pruning: pruning with both column key and partition key", + "SELECT value, hr FROM srcpart1 WHERE value IS NOT NULL AND hr < 12", + Seq("value", "hr"), + Seq("value", "hr"), + Seq( + Seq("2008-04-08", "11"), + Seq("2008-04-09", "11"))) + + def createPruningTest( + testCaseName: String, + sql: String, + expectedOutputColumns: Seq[String], + expectedScannedColumns: Seq[String], + expectedPartValues: Seq[Seq[String]]) = { + test(s"$testCaseName - pruning test") { + val plan = new TestHive.SqlQueryExecution(sql).executedPlan + val actualOutputColumns = plan.output.map(_.name) + val (actualScannedColumns, actualPartValues) = plan.collect { + case p @ HiveTableScan(columns, relation, _) => + val columnNames = columns.map(_.name) + val partValues = p.prunePartitions(relation.hiveQlPartitions).map(_.getValues) + (columnNames, partValues) + }.head + + assert(actualOutputColumns sameElements expectedOutputColumns, "Output columns mismatch") + assert(actualScannedColumns sameElements expectedScannedColumns, "Scanned columns mismatch") + + assert( + actualPartValues.length === expectedPartValues.length, + "Partition value count mismatches") + + for ((actual, expected) <- actualPartValues.zip(expectedPartValues)) { + assert(actual sameElements expected, "Partition values mismatch") + } + } + + // Creates a query test to compare query results generated by Hive and Catalyst. + createQueryTest(s"$testCaseName - query test", sql) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala new file mode 100644 index 0000000000..ee90061c7c --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.parquet
+
+import java.io.File
+
+import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.sql.catalyst.expressions.Row
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.util.getTempFilePath
+import org.apache.spark.sql.hive.TestHive
+
+class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach {
+
+  val filename = getTempFilePath("parquettest").getCanonicalFile.toURI.toString
+
+  // Runs a SQL query and optionally resolves one Parquet table first.
+  def runQuery(
+      querystr: String,
+      tableName: Option[String] = None,
+      filename: Option[String] = None): Array[Row] = {
+    // Resolve references here so that CREATE TABLE AS works.
+    val query = TestHive.parseSql(querystr)
+    val finalQuery =
+      if (tableName.nonEmpty && filename.nonEmpty)
+        resolveParquetTable(tableName.get, filename.get, query)
+      else
+        query
+    TestHive.executePlan(finalQuery)
+      .toRdd
+      .collect()
+  }
+
+  // Stores the output of a query in a Parquet file.
+  def storeQuery(querystr: String, filename: String): Unit = {
+    val query = WriteToFile(filename, TestHive.parseSql(querystr))
+    TestHive.executePlan(query).stringResult()
+  }
+
+  /**
+   * TODO: This function is necessary as long as there is no notion of a Catalog for
+   * Parquet tables. Once such a thing exists this functionality should be moved there.
+   */
+  def resolveParquetTable(tableName: String, filename: String, plan: LogicalPlan): LogicalPlan = {
+    TestHive.loadTestTable("src") // May not have been loaded yet.
+    plan.transform {
+      case relation @ UnresolvedRelation(databaseName, name, alias) =>
+        if (name == tableName)
+          ParquetRelation(tableName, filename)
+        else
+          relation
+      case op @ InsertIntoCreatedTable(databaseName, name, child) =>
+        if (name == tableName) {
+          // Note: at this stage the plan is not yet analyzed, but Parquet needs to know
+          // the schema, and for that the child must be resolved.
+          val relation = ParquetRelation.create(
+            filename,
+            TestHive.analyzer(child),
+            TestHive.sparkContext.hadoopConfiguration,
+            Some(tableName))
+          InsertIntoTable(
+            relation.asInstanceOf[BaseRelation],
+            Map.empty,
+            child,
+            overwrite = false)
+        } else {
+          op
+        }
+    }
+  }
+
+  override def beforeAll() {
+    // Write the test data.
+    ParquetTestData.writeFile
+    // Override the initial Parquet test table.
+    TestHive.catalog.registerTable(Some("parquet"), "testsource", ParquetTestData.testData)
+  }
+
+  override def afterAll() {
+    ParquetTestData.testFile.delete()
+  }
+
+  override def beforeEach() {
+    new File(filename).getAbsoluteFile.delete()
+  }
+
+  override def afterEach() {
+    new File(filename).getAbsoluteFile.delete()
+  }
+
+  test("SELECT on Parquet table") {
+    val rdd = runQuery("SELECT * FROM parquet.testsource")
+    assert(rdd != null)
+    assert(rdd.forall(_.size == 6))
+  }
+
+  test("Simple column projection + filter on Parquet table") {
+    val rdd = runQuery("SELECT myboolean, mylong FROM parquet.testsource WHERE myboolean=true")
+    assert(rdd.size === 5, "Filter returned incorrect number of rows")
+    assert(rdd.forall(_.getBoolean(0)), "Filter returned incorrect Boolean field value")
+  }
+
+  test("Converting Hive to Parquet Table via WriteToFile") {
+    storeQuery("SELECT * FROM src", filename)
+    val rddOne = runQuery("SELECT * FROM src").sortBy(_.getInt(0))
+    val rddTwo = runQuery("SELECT * FROM ptable", Some("ptable"), Some(filename))
+      .sortBy(_.getInt(0))
+    compareRDDs(rddOne, rddTwo, "src (Hive)", Seq("key:Int", "value:String"))
+  }
+
+  test("INSERT OVERWRITE TABLE Parquet table") {
+    storeQuery("SELECT * FROM parquet.testsource", filename)
+    // Insert twice so that a non-overwriting (appending) write would be caught
+    // by the row comparison below.
+    runQuery("INSERT OVERWRITE TABLE ptable SELECT * FROM parquet.testsource", Some("ptable"), Some(filename))
+    runQuery("INSERT OVERWRITE TABLE ptable SELECT * FROM parquet.testsource", Some("ptable"), Some(filename))
+    val rddCopy = runQuery("SELECT * FROM ptable", Some("ptable"), Some(filename))
+    val rddOrig = runQuery("SELECT * FROM parquet.testsource")
+    compareRDDs(rddOrig, rddCopy, "parquet.testsource", ParquetTestData.testSchemaFieldNames)
+  }
+
+  test("CREATE TABLE AS Parquet table") {
+    runQuery("CREATE TABLE ptable AS SELECT * FROM src", Some("ptable"), Some(filename))
+    val rddCopy = runQuery("SELECT * FROM ptable", Some("ptable"), Some(filename))
+      .sortBy[Int](_.apply(0) match {
+        case x: Int => x
+        case _ => 0
+      })
+    val rddOrig = runQuery("SELECT * FROM src").sortBy(_.getInt(0))
+    compareRDDs(rddOrig, rddCopy, "src (Hive)", Seq("key:Int", "value:String"))
+  }
+
+  private def compareRDDs(
+      rddOne: Array[Row],
+      rddTwo: Array[Row],
+      tableName: String,
+      fieldNames: Seq[String]): Unit = {
+    var counter = 0
+    (rddOne, rddTwo).zipped.foreach { (a, b) =>
+      (a, b).zipped.toArray.zipWithIndex.foreach {
+        // Byte arrays are compared by decoded string content rather than by reference.
+        case ((value1: Array[Byte], value2: Array[Byte]), index) =>
+          assert(new String(value1) === new String(value2),
+            s"table $tableName row $counter field ${fieldNames(index)} doesn't match")
+        case ((value1, value2), index) =>
+          assert(value1 === value2,
+            s"table $tableName row $counter field ${fieldNames(index)} doesn't match")
+      }
+      counter += 1
+    }
+  }
+}
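// A hedged, self-contained sketch of the comparison rule compareRDDs applies
// above: Array[Byte] cells are compared by decoded string content (plain ==
// on arrays compares references), everything else by value. The names here
// are illustrative, not APIs from this patch.
object RowComparisonSketch {
  def cellsEqual(a: Any, b: Any): Boolean = (a, b) match {
    case (x: Array[Byte], y: Array[Byte]) => new String(x) == new String(y)
    case (x, y) => x == y
  }

  def rowsEqual(r1: Seq[Any], r2: Seq[Any]): Boolean =
    r1.length == r2.length && r1.zip(r2).forall { case (a, b) => cellsEqual(a, b) }
}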