diff options
author | Cheng Lian <lian.cs.zju@gmail.com> | 2014-07-25 12:20:49 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2014-07-25 12:20:49 -0700 |
commit | 06dc0d2c6b69c5d59b4d194ced2ac85bfe2e05e2 (patch) | |
tree | 5f3189e690ac7f2fe68426c6763f7900e9aba5b6 /sql/hive-thriftserver | |
parent | 32bcf9af94b39f2c509eb54f8565fb659c70ca97 (diff) | |
download | spark-06dc0d2c6b69c5d59b4d194ced2ac85bfe2e05e2.tar.gz spark-06dc0d2c6b69c5d59b4d194ced2ac85bfe2e05e2.tar.bz2 spark-06dc0d2c6b69c5d59b4d194ced2ac85bfe2e05e2.zip |
[SPARK-2410][SQL] Merging Hive Thrift/JDBC server
JIRA issue:
- Main: [SPARK-2410](https://issues.apache.org/jira/browse/SPARK-2410)
- Related: [SPARK-2678](https://issues.apache.org/jira/browse/SPARK-2678)
Cherry picked the Hive Thrift/JDBC server from [branch-1.0-jdbc](https://github.com/apache/spark/tree/branch-1.0-jdbc).
(Thanks chenghao-intel for his initial contribution of the Spark SQL CLI.)
TODO
- [x] Use `spark-submit` to launch the server, the CLI and beeline
- [x] Migration guideline draft for Shark users
----
Hit by a bug in `SparkSubmitArguments` while working on this PR: all application options that are recognized by `SparkSubmitArguments` are stolen as `SparkSubmit` options. For example:
```bash
$ spark-submit --class org.apache.hive.beeline.BeeLine spark-internal --help
```
This actually shows usage information of `SparkSubmit` rather than `BeeLine`.
~~Fixed this bug here since the `spark-internal` related stuff also touches `SparkSubmitArguments` and I'd like to avoid conflict.~~
**UPDATE** The bug mentioned above is now tracked by [SPARK-2678](https://issues.apache.org/jira/browse/SPARK-2678). Decided to revert changes to this bug since it involves more subtle considerations and worth a separate PR.
Author: Cheng Lian <lian.cs.zju@gmail.com>
Closes #1399 from liancheng/thriftserver and squashes the following commits:
090beea [Cheng Lian] Revert changes related to SPARK-2678, decided to move them to another PR
21c6cf4 [Cheng Lian] Updated Spark SQL programming guide docs
fe0af31 [Cheng Lian] Reordered spark-submit options in spark-shell[.cmd]
199e3fb [Cheng Lian] Disabled MIMA for hive-thriftserver
1083e9d [Cheng Lian] Fixed failed test suites
7db82a1 [Cheng Lian] Fixed spark-submit application options handling logic
9cc0f06 [Cheng Lian] Starts beeline with spark-submit
cfcf461 [Cheng Lian] Updated documents and build scripts for the newly added hive-thriftserver profile
061880f [Cheng Lian] Addressed all comments by @pwendell
7755062 [Cheng Lian] Adapts test suites to spark-submit settings
40bafef [Cheng Lian] Fixed more license header issues
e214aab [Cheng Lian] Added missing license headers
b8905ba [Cheng Lian] Fixed minor issues in spark-sql and start-thriftserver.sh
f975d22 [Cheng Lian] Updated docs for Hive compatibility and Shark migration guide draft
3ad4e75 [Cheng Lian] Starts spark-sql shell with spark-submit
a5310d1 [Cheng Lian] Make HiveThriftServer2 play well with spark-submit
61f39f4 [Cheng Lian] Starts Hive Thrift server via spark-submit
2c4c539 [Cheng Lian] Cherry picked the Hive Thrift server
Diffstat (limited to 'sql/hive-thriftserver')
13 files changed, 1303 insertions, 0 deletions
diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml new file mode 100644 index 0000000000..7fac90fdc5 --- /dev/null +++ b/sql/hive-thriftserver/pom.xml @@ -0,0 +1,82 @@ +<?xml version="1.0" encoding="UTF-8"?> +<!-- + ~ Licensed to the Apache Software Foundation (ASF) under one or more + ~ contributor license agreements. See the NOTICE file distributed with + ~ this work for additional information regarding copyright ownership. + ~ The ASF licenses this file to You under the Apache License, Version 2.0 + ~ (the "License"); you may not use this file except in compliance with + ~ the License. You may obtain a copy of the License at + ~ + ~ http://www.apache.org/licenses/LICENSE-2.0 + ~ + ~ Unless required by applicable law or agreed to in writing, software + ~ distributed under the License is distributed on an "AS IS" BASIS, + ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + ~ See the License for the specific language governing permissions and + ~ limitations under the License. + --> + +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <modelVersion>4.0.0</modelVersion> + <parent> + <groupId>org.apache.spark</groupId> + <artifactId>spark-parent</artifactId> + <version>1.1.0-SNAPSHOT</version> + <relativePath>../../pom.xml</relativePath> + </parent> + + <groupId>org.apache.spark</groupId> + <artifactId>spark-hive-thriftserver_2.10</artifactId> + <packaging>jar</packaging> + <name>Spark Project Hive</name> + <url>http://spark.apache.org/</url> + <properties> + <sbt.project.name>hive-thriftserver</sbt.project.name> + </properties> + + <dependencies> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-hive_${scala.binary.version}</artifactId> + <version>${project.version}</version> + </dependency> + <dependency> + <groupId>org.spark-project.hive</groupId> + <artifactId>hive-cli</artifactId> + <version>${hive.version}</version> + </dependency> + <dependency> + <groupId>org.spark-project.hive</groupId> + <artifactId>hive-jdbc</artifactId> + <version>${hive.version}</version> + </dependency> + <dependency> + <groupId>org.spark-project.hive</groupId> + <artifactId>hive-beeline</artifactId> + <version>${hive.version}</version> + </dependency> + <dependency> + <groupId>org.scalatest</groupId> + <artifactId>scalatest_${scala.binary.version}</artifactId> + <scope>test</scope> + </dependency> + </dependencies> + <build> + <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory> + <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory> + <plugins> + <plugin> + <groupId>org.scalatest</groupId> + <artifactId>scalatest-maven-plugin</artifactId> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-deploy-plugin</artifactId> + <configuration> + <skip>true</skip> + </configuration> + </plugin> + </plugins> + </build> +</project> diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala new file mode 100644 index 0000000000..ddbc2a79fb --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ + +import org.apache.commons.logging.LogFactory +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hive.service.cli.thrift.ThriftBinaryCLIService +import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} + +import org.apache.spark.sql.Logging +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ + +/** + * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a + * `HiveThriftServer2` thrift server. + */ +private[hive] object HiveThriftServer2 extends Logging { + var LOG = LogFactory.getLog(classOf[HiveServer2]) + + def main(args: Array[String]) { + val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2") + + if (!optionsProcessor.process(args)) { + logger.warn("Error starting HiveThriftServer2 with given arguments") + System.exit(-1) + } + + val ss = new SessionState(new HiveConf(classOf[SessionState])) + + // Set all properties specified via command line. + val hiveConf: HiveConf = ss.getConf + hiveConf.getAllProperties.toSeq.sortBy(_._1).foreach { case (k, v) => + logger.debug(s"HiveConf var: $k=$v") + } + + SessionState.start(ss) + + logger.info("Starting SparkContext") + SparkSQLEnv.init() + SessionState.start(ss) + + Runtime.getRuntime.addShutdownHook( + new Thread() { + override def run() { + SparkSQLEnv.sparkContext.stop() + } + } + ) + + try { + val server = new HiveThriftServer2(SparkSQLEnv.hiveContext) + server.init(hiveConf) + server.start() + logger.info("HiveThriftServer2 started") + } catch { + case e: Exception => + logger.error("Error starting HiveThriftServer2", e) + System.exit(-1) + } + } +} + +private[hive] class HiveThriftServer2(hiveContext: HiveContext) + extends HiveServer2 + with ReflectedCompositeService { + + override def init(hiveConf: HiveConf) { + val sparkSqlCliService = new SparkSQLCLIService(hiveContext) + setSuperField(this, "cliService", sparkSqlCliService) + addService(sparkSqlCliService) + + val thriftCliService = new ThriftBinaryCLIService(sparkSqlCliService) + setSuperField(this, "thriftCLIService", thriftCliService) + addService(thriftCliService) + + initCompositeService(hiveConf) + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala new file mode 100644 index 0000000000..599294dfbb --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +private[hive] object ReflectionUtils { + def setSuperField(obj : Object, fieldName: String, fieldValue: Object) { + setAncestorField(obj, 1, fieldName, fieldValue) + } + + def setAncestorField(obj: AnyRef, level: Int, fieldName: String, fieldValue: AnyRef) { + val ancestor = Iterator.iterate[Class[_]](obj.getClass)(_.getSuperclass).drop(level).next() + val field = ancestor.getDeclaredField(fieldName) + field.setAccessible(true) + field.set(obj, fieldValue) + } + + def getSuperField[T](obj: AnyRef, fieldName: String): T = { + getAncestorField[T](obj, 1, fieldName) + } + + def getAncestorField[T](clazz: Object, level: Int, fieldName: String): T = { + val ancestor = Iterator.iterate[Class[_]](clazz.getClass)(_.getSuperclass).drop(level).next() + val field = ancestor.getDeclaredField(fieldName) + field.setAccessible(true) + field.get(clazz).asInstanceOf[T] + } + + def invokeStatic(clazz: Class[_], methodName: String, args: (Class[_], AnyRef)*): AnyRef = { + invoke(clazz, null, methodName, args: _*) + } + + def invoke( + clazz: Class[_], + obj: AnyRef, + methodName: String, + args: (Class[_], AnyRef)*): AnyRef = { + + val (types, values) = args.unzip + val method = clazz.getDeclaredMethod(methodName, types: _*) + method.setAccessible(true) + method.invoke(obj, values.toSeq: _*) + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala new file mode 100755 index 0000000000..27268ecb92 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -0,0 +1,344 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ + +import java.io._ +import java.util.{ArrayList => JArrayList} + +import jline.{ConsoleReader, History} +import org.apache.commons.lang.StringUtils +import org.apache.commons.logging.LogFactory +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.cli.{CliDriver, CliSessionState, OptionsProcessor} +import org.apache.hadoop.hive.common.LogUtils.LogInitializationException +import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, LogUtils} +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.ql.exec.Utilities +import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory} +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.shims.ShimLoader +import org.apache.thrift.transport.TSocket + +import org.apache.spark.sql.Logging + +private[hive] object SparkSQLCLIDriver { + private var prompt = "spark-sql" + private var continuedPrompt = "".padTo(prompt.length, ' ') + private var transport:TSocket = _ + + installSignalHandler() + + /** + * Install an interrupt callback to cancel all Spark jobs. In Hive's CliDriver#processLine(), + * a signal handler will invoke this registered callback if a Ctrl+C signal is detected while + * a command is being processed by the current thread. + */ + def installSignalHandler() { + HiveInterruptUtils.add(new HiveInterruptCallback { + override def interrupt() { + // Handle remote execution mode + if (SparkSQLEnv.sparkContext != null) { + SparkSQLEnv.sparkContext.cancelAllJobs() + } else { + if (transport != null) { + // Force closing of TCP connection upon session termination + transport.getSocket.close() + } + } + } + }) + } + + def main(args: Array[String]) { + val oproc = new OptionsProcessor() + if (!oproc.process_stage1(args)) { + System.exit(1) + } + + // NOTE: It is critical to do this here so that log4j is reinitialized + // before any of the other core hive classes are loaded + var logInitFailed = false + var logInitDetailMessage: String = null + try { + logInitDetailMessage = LogUtils.initHiveLog4j() + } catch { + case e: LogInitializationException => + logInitFailed = true + logInitDetailMessage = e.getMessage + } + + val sessionState = new CliSessionState(new HiveConf(classOf[SessionState])) + + sessionState.in = System.in + try { + sessionState.out = new PrintStream(System.out, true, "UTF-8") + sessionState.info = new PrintStream(System.err, true, "UTF-8") + sessionState.err = new PrintStream(System.err, true, "UTF-8") + } catch { + case e: UnsupportedEncodingException => System.exit(3) + } + + if (!oproc.process_stage2(sessionState)) { + System.exit(2) + } + + if (!sessionState.getIsSilent) { + if (logInitFailed) System.err.println(logInitDetailMessage) + else SessionState.getConsole.printInfo(logInitDetailMessage) + } + + // Set all properties specified via command line. + val conf: HiveConf = sessionState.getConf + sessionState.cmdProperties.entrySet().foreach { item: java.util.Map.Entry[Object, Object] => + conf.set(item.getKey.asInstanceOf[String], item.getValue.asInstanceOf[String]) + sessionState.getOverriddenConfigurations.put( + item.getKey.asInstanceOf[String], item.getValue.asInstanceOf[String]) + } + + SessionState.start(sessionState) + + // Clean up after we exit + Runtime.getRuntime.addShutdownHook( + new Thread() { + override def run() { + SparkSQLEnv.stop() + } + } + ) + + // "-h" option has been passed, so connect to Hive thrift server. + if (sessionState.getHost != null) { + sessionState.connect() + if (sessionState.isRemoteMode) { + prompt = s"[${sessionState.getHost}:${sessionState.getPort}]" + prompt + continuedPrompt = "".padTo(prompt.length, ' ') + } + } + + if (!sessionState.isRemoteMode && !ShimLoader.getHadoopShims.usesJobShell()) { + // Hadoop-20 and above - we need to augment classpath using hiveconf + // components. + // See also: code in ExecDriver.java + var loader = conf.getClassLoader + val auxJars = HiveConf.getVar(conf, HiveConf.ConfVars.HIVEAUXJARS) + if (StringUtils.isNotBlank(auxJars)) { + loader = Utilities.addToClassPath(loader, StringUtils.split(auxJars, ",")) + } + conf.setClassLoader(loader) + Thread.currentThread().setContextClassLoader(loader) + } + + val cli = new SparkSQLCLIDriver + cli.setHiveVariables(oproc.getHiveVariables) + + // TODO work around for set the log output to console, because the HiveContext + // will set the output into an invalid buffer. + sessionState.in = System.in + try { + sessionState.out = new PrintStream(System.out, true, "UTF-8") + sessionState.info = new PrintStream(System.err, true, "UTF-8") + sessionState.err = new PrintStream(System.err, true, "UTF-8") + } catch { + case e: UnsupportedEncodingException => System.exit(3) + } + + // Execute -i init files (always in silent mode) + cli.processInitFiles(sessionState) + + if (sessionState.execString != null) { + System.exit(cli.processLine(sessionState.execString)) + } + + try { + if (sessionState.fileName != null) { + System.exit(cli.processFile(sessionState.fileName)) + } + } catch { + case e: FileNotFoundException => + System.err.println(s"Could not open input file for reading. (${e.getMessage})") + System.exit(3) + } + + val reader = new ConsoleReader() + reader.setBellEnabled(false) + // reader.setDebug(new PrintWriter(new FileWriter("writer.debug", true))) + CliDriver.getCommandCompletor.foreach((e) => reader.addCompletor(e)) + + val historyDirectory = System.getProperty("user.home") + + try { + if (new File(historyDirectory).exists()) { + val historyFile = historyDirectory + File.separator + ".hivehistory" + reader.setHistory(new History(new File(historyFile))) + } else { + System.err.println("WARNING: Directory for Hive history file: " + historyDirectory + + " does not exist. History will not be available during this session.") + } + } catch { + case e: Exception => + System.err.println("WARNING: Encountered an error while trying to initialize Hive's " + + "history file. History will not be available during this session.") + System.err.println(e.getMessage) + } + + val clientTransportTSocketField = classOf[CliSessionState].getDeclaredField("transport") + clientTransportTSocketField.setAccessible(true) + + transport = clientTransportTSocketField.get(sessionState).asInstanceOf[TSocket] + + var ret = 0 + var prefix = "" + val currentDB = ReflectionUtils.invokeStatic(classOf[CliDriver], "getFormattedDb", + classOf[HiveConf] -> conf, classOf[CliSessionState] -> sessionState) + + def promptWithCurrentDB = s"$prompt$currentDB" + def continuedPromptWithDBSpaces = continuedPrompt + ReflectionUtils.invokeStatic( + classOf[CliDriver], "spacesForString", classOf[String] -> currentDB) + + var currentPrompt = promptWithCurrentDB + var line = reader.readLine(currentPrompt + "> ") + + while (line != null) { + if (prefix.nonEmpty) { + prefix += '\n' + } + + if (line.trim().endsWith(";") && !line.trim().endsWith("\\;")) { + line = prefix + line + ret = cli.processLine(line, true) + prefix = "" + currentPrompt = promptWithCurrentDB + } else { + prefix = prefix + line + currentPrompt = continuedPromptWithDBSpaces + } + + line = reader.readLine(currentPrompt + "> ") + } + + sessionState.close() + + System.exit(ret) + } +} + +private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { + private val sessionState = SessionState.get().asInstanceOf[CliSessionState] + + private val LOG = LogFactory.getLog("CliDriver") + + private val console = new SessionState.LogHelper(LOG) + + private val conf: Configuration = + if (sessionState != null) sessionState.getConf else new Configuration() + + // Force initializing SparkSQLEnv. This is put here but not object SparkSQLCliDriver + // because the Hive unit tests do not go through the main() code path. + if (!sessionState.isRemoteMode) { + SparkSQLEnv.init() + } + + override def processCmd(cmd: String): Int = { + val cmd_trimmed: String = cmd.trim() + val tokens: Array[String] = cmd_trimmed.split("\\s+") + val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() + if (cmd_trimmed.toLowerCase.equals("quit") || + cmd_trimmed.toLowerCase.equals("exit") || + tokens(0).equalsIgnoreCase("source") || + cmd_trimmed.startsWith("!") || + tokens(0).toLowerCase.equals("list") || + sessionState.isRemoteMode) { + val start = System.currentTimeMillis() + super.processCmd(cmd) + val end = System.currentTimeMillis() + val timeTaken: Double = (end - start) / 1000.0 + console.printInfo(s"Time taken: $timeTaken seconds") + 0 + } else { + var ret = 0 + val hconf = conf.asInstanceOf[HiveConf] + val proc: CommandProcessor = CommandProcessorFactory.get(tokens(0), hconf) + + if (proc != null) { + if (proc.isInstanceOf[Driver]) { + val driver = new SparkSQLDriver + + driver.init() + val out = sessionState.out + val start:Long = System.currentTimeMillis() + if (sessionState.getIsVerbose) { + out.println(cmd) + } + + ret = driver.run(cmd).getResponseCode + if (ret != 0) { + driver.close() + return ret + } + + val res = new JArrayList[String]() + + if (HiveConf.getBoolVar(conf, HiveConf.ConfVars.HIVE_CLI_PRINT_HEADER)) { + // Print the column names. + Option(driver.getSchema.getFieldSchemas).map { fields => + out.println(fields.map(_.getName).mkString("\t")) + } + } + + try { + while (!out.checkError() && driver.getResults(res)) { + res.foreach(out.println) + res.clear() + } + } catch { + case e:IOException => + console.printError( + s"""Failed with exception ${e.getClass.getName}: ${e.getMessage} + |${org.apache.hadoop.util.StringUtils.stringifyException(e)} + """.stripMargin) + ret = 1 + } + + val cret = driver.close() + if (ret == 0) { + ret = cret + } + + val end = System.currentTimeMillis() + if (end > start) { + val timeTaken:Double = (end - start) / 1000.0 + console.printInfo(s"Time taken: $timeTaken seconds", null) + } + + // Destroy the driver to release all the locks. + driver.destroy() + } else { + if (sessionState.getIsVerbose) { + sessionState.out.println(tokens(0) + " " + cmd_1) + } + ret = proc.run(cmd_1).getResponseCode + } + } + ret + } + } +} + diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala new file mode 100644 index 0000000000..42cbf363b2 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ + +import java.io.IOException +import java.util.{List => JList} +import javax.security.auth.login.LoginException + +import org.apache.commons.logging.Log +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.shims.ShimLoader +import org.apache.hive.service.Service.STATE +import org.apache.hive.service.auth.HiveAuthFactory +import org.apache.hive.service.cli.CLIService +import org.apache.hive.service.{AbstractService, Service, ServiceException} + +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ + +private[hive] class SparkSQLCLIService(hiveContext: HiveContext) + extends CLIService + with ReflectedCompositeService { + + override def init(hiveConf: HiveConf) { + setSuperField(this, "hiveConf", hiveConf) + + val sparkSqlSessionManager = new SparkSQLSessionManager(hiveContext) + setSuperField(this, "sessionManager", sparkSqlSessionManager) + addService(sparkSqlSessionManager) + + try { + HiveAuthFactory.loginFromKeytab(hiveConf) + val serverUserName = ShimLoader.getHadoopShims + .getShortUserName(ShimLoader.getHadoopShims.getUGIForConf(hiveConf)) + setSuperField(this, "serverUserName", serverUserName) + } catch { + case e @ (_: IOException | _: LoginException) => + throw new ServiceException("Unable to login to kerberos with given principal/keytab", e) + } + + initCompositeService(hiveConf) + } +} + +private[thriftserver] trait ReflectedCompositeService { this: AbstractService => + def initCompositeService(hiveConf: HiveConf) { + // Emulating `CompositeService.init(hiveConf)` + val serviceList = getAncestorField[JList[Service]](this, 2, "serviceList") + serviceList.foreach(_.init(hiveConf)) + + // Emulating `AbstractService.init(hiveConf)` + invoke(classOf[AbstractService], this, "ensureCurrentState", classOf[STATE] -> STATE.NOTINITED) + setAncestorField(this, 3, "hiveConf", hiveConf) + invoke(classOf[AbstractService], this, "changeState", classOf[STATE] -> STATE.INITED) + getAncestorField[Log](this, 3, "LOG").info(s"Service: $getName is inited.") + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala new file mode 100644 index 0000000000..5202aa9903 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ + +import java.util.{ArrayList => JArrayList} + +import org.apache.commons.lang.exception.ExceptionUtils +import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} +import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse + +import org.apache.spark.sql.Logging +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} + +private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveContext) + extends Driver with Logging { + + private var tableSchema: Schema = _ + private var hiveResponse: Seq[String] = _ + + override def init(): Unit = { + } + + private def getResultSetSchema(query: context.QueryExecution): Schema = { + val analyzed = query.analyzed + logger.debug(s"Result Schema: ${analyzed.output}") + if (analyzed.output.size == 0) { + new Schema(new FieldSchema("Response code", "string", "") :: Nil, null) + } else { + val fieldSchemas = analyzed.output.map { attr => + new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") + } + + new Schema(fieldSchemas, null) + } + } + + override def run(command: String): CommandProcessorResponse = { + val execution = context.executePlan(context.hql(command).logicalPlan) + + // TODO unify the error code + try { + hiveResponse = execution.stringResult() + tableSchema = getResultSetSchema(execution) + new CommandProcessorResponse(0) + } catch { + case cause: Throwable => + logger.error(s"Failed in [$command]", cause) + new CommandProcessorResponse(-3, ExceptionUtils.getFullStackTrace(cause), null) + } + } + + override def close(): Int = { + hiveResponse = null + tableSchema = null + 0 + } + + override def getSchema: Schema = tableSchema + + override def getResults(res: JArrayList[String]): Boolean = { + if (hiveResponse == null) { + false + } else { + res.addAll(hiveResponse) + hiveResponse = null + true + } + } + + override def destroy() { + super.destroy() + hiveResponse = null + tableSchema = null + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala new file mode 100644 index 0000000000..451c3bd7b9 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import org.apache.hadoop.hive.ql.session.SessionState + +import org.apache.spark.scheduler.{SplitInfo, StatsReportListener} +import org.apache.spark.sql.Logging +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.{SparkConf, SparkContext} + +/** A singleton object for the master program. The slaves should not access this. */ +private[hive] object SparkSQLEnv extends Logging { + logger.debug("Initializing SparkSQLEnv") + + var hiveContext: HiveContext = _ + var sparkContext: SparkContext = _ + + def init() { + if (hiveContext == null) { + sparkContext = new SparkContext(new SparkConf() + .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}")) + + sparkContext.addSparkListener(new StatsReportListener()) + + hiveContext = new HiveContext(sparkContext) { + @transient override lazy val sessionState = SessionState.get() + @transient override lazy val hiveconf = sessionState.getConf + } + } + } + + /** Cleans up and shuts down the Spark SQL environments. */ + def stop() { + logger.debug("Shutting down Spark SQL Environment") + // Stop the SparkContext + if (SparkSQLEnv.sparkContext != null) { + sparkContext.stop() + sparkContext = null + hiveContext = null + } + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala new file mode 100644 index 0000000000..6b3275b4ea --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.util.concurrent.Executors + +import org.apache.commons.logging.Log +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hive.service.cli.session.SessionManager + +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ +import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager + +private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) + extends SessionManager + with ReflectedCompositeService { + + override def init(hiveConf: HiveConf) { + setSuperField(this, "hiveConf", hiveConf) + + val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) + setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) + getAncestorField[Log](this, 3, "LOG").info( + s"HiveServer2: Async execution pool size $backgroundPoolSize") + + val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) + setSuperField(this, "operationManager", sparkSqlOperationManager) + addService(sparkSqlOperationManager) + + initCompositeService(hiveConf) + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala new file mode 100644 index 0000000000..a4e1f3e762 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver.server + +import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer +import scala.math.{random, round} + +import java.sql.Timestamp +import java.util.{Map => JMap} + +import org.apache.hadoop.hive.common.`type`.HiveDecimal +import org.apache.hadoop.hive.metastore.api.FieldSchema +import org.apache.hive.service.cli._ +import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager} +import org.apache.hive.service.cli.session.HiveSession + +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} +import org.apache.spark.sql.{Logging, SchemaRDD, Row => SparkRow} + +/** + * Executes queries using Spark SQL, and maintains a list of handles to active queries. + */ +class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManager with Logging { + val handleToOperation = ReflectionUtils + .getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation") + + override def newExecuteStatementOperation( + parentSession: HiveSession, + statement: String, + confOverlay: JMap[String, String], + async: Boolean): ExecuteStatementOperation = synchronized { + + val operation = new ExecuteStatementOperation(parentSession, statement, confOverlay) { + private var result: SchemaRDD = _ + private var iter: Iterator[SparkRow] = _ + private var dataTypes: Array[DataType] = _ + + def close(): Unit = { + // RDDs will be cleaned automatically upon garbage collection. + logger.debug("CLOSING") + } + + def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = { + if (!iter.hasNext) { + new RowSet() + } else { + val maxRows = maxRowsL.toInt // Do you really want a row batch larger than Int Max? No. + var curRow = 0 + var rowSet = new ArrayBuffer[Row](maxRows) + + while (curRow < maxRows && iter.hasNext) { + val sparkRow = iter.next() + val row = new Row() + var curCol = 0 + + while (curCol < sparkRow.length) { + dataTypes(curCol) match { + case StringType => + row.addString(sparkRow(curCol).asInstanceOf[String]) + case IntegerType => + row.addColumnValue(ColumnValue.intValue(sparkRow.getInt(curCol))) + case BooleanType => + row.addColumnValue(ColumnValue.booleanValue(sparkRow.getBoolean(curCol))) + case DoubleType => + row.addColumnValue(ColumnValue.doubleValue(sparkRow.getDouble(curCol))) + case FloatType => + row.addColumnValue(ColumnValue.floatValue(sparkRow.getFloat(curCol))) + case DecimalType => + val hiveDecimal = sparkRow.get(curCol).asInstanceOf[BigDecimal].bigDecimal + row.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal))) + case LongType => + row.addColumnValue(ColumnValue.longValue(sparkRow.getLong(curCol))) + case ByteType => + row.addColumnValue(ColumnValue.byteValue(sparkRow.getByte(curCol))) + case ShortType => + row.addColumnValue(ColumnValue.intValue(sparkRow.getShort(curCol))) + case TimestampType => + row.addColumnValue( + ColumnValue.timestampValue(sparkRow.get(curCol).asInstanceOf[Timestamp])) + case BinaryType | _: ArrayType | _: StructType | _: MapType => + val hiveString = result + .queryExecution + .asInstanceOf[HiveContext#QueryExecution] + .toHiveString((sparkRow.get(curCol), dataTypes(curCol))) + row.addColumnValue(ColumnValue.stringValue(hiveString)) + } + curCol += 1 + } + rowSet += row + curRow += 1 + } + new RowSet(rowSet, 0) + } + } + + def getResultSetSchema: TableSchema = { + logger.warn(s"Result Schema: ${result.queryExecution.analyzed.output}") + if (result.queryExecution.analyzed.output.size == 0) { + new TableSchema(new FieldSchema("Result", "string", "") :: Nil) + } else { + val schema = result.queryExecution.analyzed.output.map { attr => + new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") + } + new TableSchema(schema) + } + } + + def run(): Unit = { + logger.info(s"Running query '$statement'") + setState(OperationState.RUNNING) + try { + result = hiveContext.hql(statement) + logger.debug(result.queryExecution.toString()) + val groupId = round(random * 1000000).toString + hiveContext.sparkContext.setJobGroup(groupId, statement) + iter = result.queryExecution.toRdd.toLocalIterator + dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray + setHasResultSet(true) + } catch { + // Actually do need to catch Throwable as some failures don't inherit from Exception and + // HiveServer will silently swallow them. + case e: Throwable => + logger.error("Error executing query:",e) + throw new HiveSQLException(e.toString) + } + setState(OperationState.FINISHED) + } + } + + handleToOperation.put(operation.getHandle, operation) + operation + } +} diff --git a/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt b/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt new file mode 100644 index 0000000000..850f8014b6 --- /dev/null +++ b/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt @@ -0,0 +1,5 @@ +238val_238 +86val_86 +311val_311 +27val_27 +165val_165 diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala new file mode 100644 index 0000000000..b90670a796 --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.io.{BufferedReader, InputStreamReader, PrintWriter} + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.sql.hive.test.TestHive + +class CliSuite extends FunSuite with BeforeAndAfterAll with TestUtils { + val WAREHOUSE_PATH = TestUtils.getWarehousePath("cli") + val METASTORE_PATH = TestUtils.getMetastorePath("cli") + + override def beforeAll() { + val pb = new ProcessBuilder( + "../../bin/spark-sql", + "--master", + "local", + "--hiveconf", + s"javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$METASTORE_PATH;create=true", + "--hiveconf", + "hive.metastore.warehouse.dir=" + WAREHOUSE_PATH) + + process = pb.start() + outputWriter = new PrintWriter(process.getOutputStream, true) + inputReader = new BufferedReader(new InputStreamReader(process.getInputStream)) + errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream)) + waitForOutput(inputReader, "spark-sql>") + } + + override def afterAll() { + process.destroy() + process.waitFor() + } + + test("simple commands") { + val dataFilePath = getDataFile("data/files/small_kv.txt") + executeQuery("create table hive_test1(key int, val string);") + executeQuery("load data local inpath '" + dataFilePath+ "' overwrite into table hive_test1;") + executeQuery("cache table hive_test1", "Time taken") + } +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala new file mode 100644 index 0000000000..59f4952b78 --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent._ + +import java.io.{BufferedReader, InputStreamReader} +import java.sql.{Connection, DriverManager, Statement} + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.sql.Logging +import org.apache.spark.sql.catalyst.util.getTempFilePath + +/** + * Test for the HiveThriftServer2 using JDBC. + */ +class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUtils with Logging { + + val WAREHOUSE_PATH = getTempFilePath("warehouse") + val METASTORE_PATH = getTempFilePath("metastore") + + val DRIVER_NAME = "org.apache.hive.jdbc.HiveDriver" + val TABLE = "test" + // use a different port, than the hive standard 10000, + // for tests to avoid issues with the port being taken on some machines + val PORT = "10000" + + // If verbose is true, the test program will print all outputs coming from the Hive Thrift server. + val VERBOSE = Option(System.getenv("SPARK_SQL_TEST_VERBOSE")).getOrElse("false").toBoolean + + Class.forName(DRIVER_NAME) + + override def beforeAll() { launchServer() } + + override def afterAll() { stopServer() } + + private def launchServer(args: Seq[String] = Seq.empty) { + // Forking a new process to start the Hive Thrift server. The reason to do this is it is + // hard to clean up Hive resources entirely, so we just start a new process and kill + // that process for cleanup. + val defaultArgs = Seq( + "../../sbin/start-thriftserver.sh", + "--master local", + "--hiveconf", + "hive.root.logger=INFO,console", + "--hiveconf", + s"javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$METASTORE_PATH;create=true", + "--hiveconf", + s"hive.metastore.warehouse.dir=$WAREHOUSE_PATH") + val pb = new ProcessBuilder(defaultArgs ++ args) + process = pb.start() + inputReader = new BufferedReader(new InputStreamReader(process.getInputStream)) + errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream)) + waitForOutput(inputReader, "ThriftBinaryCLIService listening on") + + // Spawn a thread to read the output from the forked process. + // Note that this is necessary since in some configurations, log4j could be blocked + // if its output to stderr are not read, and eventually blocking the entire test suite. + future { + while (true) { + val stdout = readFrom(inputReader) + val stderr = readFrom(errorReader) + if (VERBOSE && stdout.length > 0) { + println(stdout) + } + if (VERBOSE && stderr.length > 0) { + println(stderr) + } + Thread.sleep(50) + } + } + } + + private def stopServer() { + process.destroy() + process.waitFor() + } + + test("test query execution against a Hive Thrift server") { + Thread.sleep(5 * 1000) + val dataFilePath = getDataFile("data/files/small_kv.txt") + val stmt = createStatement() + stmt.execute("DROP TABLE IF EXISTS test") + stmt.execute("DROP TABLE IF EXISTS test_cached") + stmt.execute("CREATE TABLE test(key int, val string)") + stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test") + stmt.execute("CREATE TABLE test_cached as select * from test limit 4") + stmt.execute("CACHE TABLE test_cached") + + var rs = stmt.executeQuery("select count(*) from test") + rs.next() + assert(rs.getInt(1) === 5) + + rs = stmt.executeQuery("select count(*) from test_cached") + rs.next() + assert(rs.getInt(1) === 4) + + stmt.close() + } + + def getConnection: Connection = { + val connectURI = s"jdbc:hive2://localhost:$PORT/" + DriverManager.getConnection(connectURI, System.getProperty("user.name"), "") + } + + def createStatement(): Statement = getConnection.createStatement() +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala new file mode 100644 index 0000000000..bb2242618f --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.io.{BufferedReader, PrintWriter} +import java.text.SimpleDateFormat +import java.util.Date + +import org.apache.hadoop.hive.common.LogUtils +import org.apache.hadoop.hive.common.LogUtils.LogInitializationException + +object TestUtils { + val timestamp = new SimpleDateFormat("yyyyMMdd-HHmmss") + + def getWarehousePath(prefix: String): String = { + System.getProperty("user.dir") + "/test_warehouses/" + prefix + "-warehouse-" + + timestamp.format(new Date) + } + + def getMetastorePath(prefix: String): String = { + System.getProperty("user.dir") + "/test_warehouses/" + prefix + "-metastore-" + + timestamp.format(new Date) + } + + // Dummy function for initialize the log4j properties. + def init() { } + + // initialize log4j + try { + LogUtils.initHiveLog4j() + } catch { + case e: LogInitializationException => // Ignore the error. + } +} + +trait TestUtils { + var process : Process = null + var outputWriter : PrintWriter = null + var inputReader : BufferedReader = null + var errorReader : BufferedReader = null + + def executeQuery( + cmd: String, outputMessage: String = "OK", timeout: Long = 15000): String = { + println("Executing: " + cmd + ", expecting output: " + outputMessage) + outputWriter.write(cmd + "\n") + outputWriter.flush() + waitForQuery(timeout, outputMessage) + } + + protected def waitForQuery(timeout: Long, message: String): String = { + if (waitForOutput(errorReader, message, timeout)) { + Thread.sleep(500) + readOutput() + } else { + assert(false, "Didn't find \"" + message + "\" in the output:\n" + readOutput()) + null + } + } + + // Wait for the specified str to appear in the output. + protected def waitForOutput( + reader: BufferedReader, str: String, timeout: Long = 10000): Boolean = { + val startTime = System.currentTimeMillis + var out = "" + while (!out.contains(str) && System.currentTimeMillis < (startTime + timeout)) { + out += readFrom(reader) + } + out.contains(str) + } + + // Read stdout output and filter out garbage collection messages. + protected def readOutput(): String = { + val output = readFrom(inputReader) + // Remove GC Messages + val filteredOutput = output.lines.filterNot(x => x.contains("[GC") || x.contains("[Full GC")) + .mkString("\n") + filteredOutput + } + + protected def readFrom(reader: BufferedReader): String = { + var out = "" + var c = 0 + while (reader.ready) { + c = reader.read() + out += c.asInstanceOf[Char] + } + out + } + + protected def getDataFile(name: String) = { + Thread.currentThread().getContextClassLoader.getResource(name) + } +} |