author     Cheng Lian <lian.cs.zju@gmail.com>          2014-07-25 12:20:49 -0700
committer  Michael Armbrust <michael@databricks.com>   2014-07-25 12:20:49 -0700
commit     06dc0d2c6b69c5d59b4d194ced2ac85bfe2e05e2 (patch)
tree       5f3189e690ac7f2fe68426c6763f7900e9aba5b6 /sql/hive-thriftserver
parent     32bcf9af94b39f2c509eb54f8565fb659c70ca97 (diff)
[SPARK-2410][SQL] Merging Hive Thrift/JDBC server
JIRA issue:
- Main: [SPARK-2410](https://issues.apache.org/jira/browse/SPARK-2410)
- Related: [SPARK-2678](https://issues.apache.org/jira/browse/SPARK-2678)

Cherry picked the Hive Thrift/JDBC server from [branch-1.0-jdbc](https://github.com/apache/spark/tree/branch-1.0-jdbc).

(Thanks chenghao-intel for his initial contribution of the Spark SQL CLI.)

TODO

- [x] Use `spark-submit` to launch the server, the CLI and beeline
- [x] Migration guideline draft for Shark users

----

Hit by a bug in `SparkSubmitArguments` while working on this PR: all application options that are recognized by `SparkSubmitArguments` are stolen as `SparkSubmit` options. For example:

```bash
$ spark-submit --class org.apache.hive.beeline.BeeLine spark-internal --help
```

This actually shows usage information of `SparkSubmit` rather than `BeeLine`.

~~Fixed this bug here since the `spark-internal` related stuff also touches `SparkSubmitArguments` and I'd like to avoid conflict.~~

**UPDATE** The bug mentioned above is now tracked by [SPARK-2678](https://issues.apache.org/jira/browse/SPARK-2678). Decided to revert changes to this bug since it involves more subtle considerations and is worth a separate PR.

Author: Cheng Lian <lian.cs.zju@gmail.com>

Closes #1399 from liancheng/thriftserver and squashes the following commits:

090beea [Cheng Lian] Revert changes related to SPARK-2678, decided to move them to another PR
21c6cf4 [Cheng Lian] Updated Spark SQL programming guide docs
fe0af31 [Cheng Lian] Reordered spark-submit options in spark-shell[.cmd]
199e3fb [Cheng Lian] Disabled MIMA for hive-thriftserver
1083e9d [Cheng Lian] Fixed failed test suites
7db82a1 [Cheng Lian] Fixed spark-submit application options handling logic
9cc0f06 [Cheng Lian] Starts beeline with spark-submit
cfcf461 [Cheng Lian] Updated documents and build scripts for the newly added hive-thriftserver profile
061880f [Cheng Lian] Addressed all comments by @pwendell
7755062 [Cheng Lian] Adapts test suites to spark-submit settings
40bafef [Cheng Lian] Fixed more license header issues
e214aab [Cheng Lian] Added missing license headers
b8905ba [Cheng Lian] Fixed minor issues in spark-sql and start-thriftserver.sh
f975d22 [Cheng Lian] Updated docs for Hive compatibility and Shark migration guide draft
3ad4e75 [Cheng Lian] Starts spark-sql shell with spark-submit
a5310d1 [Cheng Lian] Make HiveThriftServer2 play well with spark-submit
61f39f4 [Cheng Lian] Starts Hive Thrift server via spark-submit
2c4c539 [Cheng Lian] Cherry picked the Hive Thrift server
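Once the server is running (for example via the `start-thriftserver.sh` script mentioned in the commits above), any HiveServer2-compatible JDBC client can talk to it. Below is a minimal Scala sketch mirroring what `HiveThriftServer2Suite` does; the host, port 10000, empty password, and the `test` table are assumptions taken from that suite rather than anything this patch mandates.

```scala
import java.sql.DriverManager

object ThriftServerClientExample {
  def main(args: Array[String]): Unit = {
    // Register the Hive JDBC driver pulled in by the new hive-jdbc dependency.
    Class.forName("org.apache.hive.jdbc.HiveDriver")

    // Assumes the server was started locally on Hive's default port 10000,
    // as in HiveThriftServer2Suite.getConnection.
    val connection = DriverManager.getConnection(
      "jdbc:hive2://localhost:10000/", System.getProperty("user.name"), "")

    val statement = connection.createStatement()
    // `test` is a hypothetical table; see the LOAD DATA statements in the suite.
    val resultSet = statement.executeQuery("SELECT count(*) FROM test")
    while (resultSet.next()) {
      println(resultSet.getInt(1))
    }

    statement.close()
    connection.close()
  }
}
```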
Diffstat (limited to 'sql/hive-thriftserver')
-rw-r--r--   sql/hive-thriftserver/pom.xml   82
-rw-r--r--   sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala   97
-rw-r--r--   sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala   58
-rwxr-xr-x   sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala   344
-rw-r--r--   sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala   74
-rw-r--r--   sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala   93
-rw-r--r--   sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala   58
-rw-r--r--   sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala   49
-rw-r--r--   sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala   151
-rw-r--r--   sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt   5
-rw-r--r--   sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala   59
-rw-r--r--   sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala   125
-rw-r--r--   sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala   108
13 files changed, 1303 insertions, 0 deletions
diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml
new file mode 100644
index 0000000000..7fac90fdc5
--- /dev/null
+++ b/sql/hive-thriftserver/pom.xml
@@ -0,0 +1,82 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-parent</artifactId>
+ <version>1.1.0-SNAPSHOT</version>
+ <relativePath>../../pom.xml</relativePath>
+ </parent>
+
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-hive-thriftserver_2.10</artifactId>
+ <packaging>jar</packaging>
+  <name>Spark Project Hive Thrift Server</name>
+ <url>http://spark.apache.org/</url>
+ <properties>
+ <sbt.project.name>hive-thriftserver</sbt.project.name>
+ </properties>
+
+ <dependencies>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-hive_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.spark-project.hive</groupId>
+ <artifactId>hive-cli</artifactId>
+ <version>${hive.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.spark-project.hive</groupId>
+ <artifactId>hive-jdbc</artifactId>
+ <version>${hive.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.spark-project.hive</groupId>
+ <artifactId>hive-beeline</artifactId>
+ <version>${hive.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest_${scala.binary.version}</artifactId>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+ <build>
+ <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+ <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+ <plugins>
+ <plugin>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest-maven-plugin</artifactId>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-deploy-plugin</artifactId>
+ <configuration>
+ <skip>true</skip>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+</project>
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
new file mode 100644
index 0000000000..ddbc2a79fb
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import scala.collection.JavaConversions._
+
+import org.apache.commons.logging.LogFactory
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.ql.session.SessionState
+import org.apache.hive.service.cli.thrift.ThriftBinaryCLIService
+import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor}
+
+import org.apache.spark.sql.Logging
+import org.apache.spark.sql.hive.HiveContext
+import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
+
+/**
+ * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a
+ * `HiveThriftServer2` thrift server.
+ */
+private[hive] object HiveThriftServer2 extends Logging {
+ var LOG = LogFactory.getLog(classOf[HiveServer2])
+
+ def main(args: Array[String]) {
+ val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2")
+
+ if (!optionsProcessor.process(args)) {
+ logger.warn("Error starting HiveThriftServer2 with given arguments")
+ System.exit(-1)
+ }
+
+ val ss = new SessionState(new HiveConf(classOf[SessionState]))
+
+ // Set all properties specified via command line.
+ val hiveConf: HiveConf = ss.getConf
+ hiveConf.getAllProperties.toSeq.sortBy(_._1).foreach { case (k, v) =>
+ logger.debug(s"HiveConf var: $k=$v")
+ }
+
+ SessionState.start(ss)
+
+ logger.info("Starting SparkContext")
+ SparkSQLEnv.init()
+ SessionState.start(ss)
+
+ Runtime.getRuntime.addShutdownHook(
+ new Thread() {
+ override def run() {
+ SparkSQLEnv.sparkContext.stop()
+ }
+ }
+ )
+
+ try {
+ val server = new HiveThriftServer2(SparkSQLEnv.hiveContext)
+ server.init(hiveConf)
+ server.start()
+ logger.info("HiveThriftServer2 started")
+ } catch {
+ case e: Exception =>
+ logger.error("Error starting HiveThriftServer2", e)
+ System.exit(-1)
+ }
+ }
+}
+
+private[hive] class HiveThriftServer2(hiveContext: HiveContext)
+ extends HiveServer2
+ with ReflectedCompositeService {
+
+ override def init(hiveConf: HiveConf) {
+ val sparkSqlCliService = new SparkSQLCLIService(hiveContext)
+ setSuperField(this, "cliService", sparkSqlCliService)
+ addService(sparkSqlCliService)
+
+ val thriftCliService = new ThriftBinaryCLIService(sparkSqlCliService)
+ setSuperField(this, "thriftCLIService", thriftCliService)
+ addService(thriftCliService)
+
+ initCompositeService(hiveConf)
+ }
+}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala
new file mode 100644
index 0000000000..599294dfbb
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+private[hive] object ReflectionUtils {
+ def setSuperField(obj : Object, fieldName: String, fieldValue: Object) {
+ setAncestorField(obj, 1, fieldName, fieldValue)
+ }
+
+ def setAncestorField(obj: AnyRef, level: Int, fieldName: String, fieldValue: AnyRef) {
+ val ancestor = Iterator.iterate[Class[_]](obj.getClass)(_.getSuperclass).drop(level).next()
+ val field = ancestor.getDeclaredField(fieldName)
+ field.setAccessible(true)
+ field.set(obj, fieldValue)
+ }
+
+ def getSuperField[T](obj: AnyRef, fieldName: String): T = {
+ getAncestorField[T](obj, 1, fieldName)
+ }
+
+ def getAncestorField[T](clazz: Object, level: Int, fieldName: String): T = {
+ val ancestor = Iterator.iterate[Class[_]](clazz.getClass)(_.getSuperclass).drop(level).next()
+ val field = ancestor.getDeclaredField(fieldName)
+ field.setAccessible(true)
+ field.get(clazz).asInstanceOf[T]
+ }
+
+ def invokeStatic(clazz: Class[_], methodName: String, args: (Class[_], AnyRef)*): AnyRef = {
+ invoke(clazz, null, methodName, args: _*)
+ }
+
+ def invoke(
+ clazz: Class[_],
+ obj: AnyRef,
+ methodName: String,
+ args: (Class[_], AnyRef)*): AnyRef = {
+
+ val (types, values) = args.unzip
+ val method = clazz.getDeclaredMethod(methodName, types: _*)
+ method.setAccessible(true)
+ method.invoke(obj, values.toSeq: _*)
+ }
+}
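These helpers exist because HiveServer2 keeps its service wiring in private fields. A minimal sketch of how they are meant to be used follows; the `Base`/`Derived` classes are hypothetical stand-ins, and the real call sites are `HiveThriftServer2.init` and `ReflectedCompositeService` below.

```scala
package org.apache.spark.sql.hive.thriftserver

// Declared inside the thriftserver package because ReflectionUtils is private[hive].
import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._

// Hypothetical stand-ins for HiveServer2's service classes.
class Base { private var name: String = "base" }
class Derived extends Base

object ReflectionUtilsExample {
  def main(args: Array[String]): Unit = {
    val obj = new Derived

    // Overwrite a private field declared one superclass up, just as
    // HiveThriftServer2.init replaces "cliService" on HiveServer2.
    setSuperField(obj, "name", "spark")

    // Read it back; `level` counts superclasses starting from obj.getClass.
    println(getAncestorField[String](obj, 1, "name")) // prints "spark"
  }
}
```

In the real code the same pattern swaps Spark-backed services into `HiveServer2`, `CLIService` and `SessionManager` without forking those classes.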
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
new file mode 100755
index 0000000000..27268ecb92
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import scala.collection.JavaConversions._
+
+import java.io._
+import java.util.{ArrayList => JArrayList}
+
+import jline.{ConsoleReader, History}
+import org.apache.commons.lang.StringUtils
+import org.apache.commons.logging.LogFactory
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.hive.cli.{CliDriver, CliSessionState, OptionsProcessor}
+import org.apache.hadoop.hive.common.LogUtils.LogInitializationException
+import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, LogUtils}
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.ql.Driver
+import org.apache.hadoop.hive.ql.exec.Utilities
+import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory}
+import org.apache.hadoop.hive.ql.session.SessionState
+import org.apache.hadoop.hive.shims.ShimLoader
+import org.apache.thrift.transport.TSocket
+
+import org.apache.spark.sql.Logging
+
+private[hive] object SparkSQLCLIDriver {
+ private var prompt = "spark-sql"
+ private var continuedPrompt = "".padTo(prompt.length, ' ')
+ private var transport:TSocket = _
+
+ installSignalHandler()
+
+ /**
+ * Install an interrupt callback to cancel all Spark jobs. In Hive's CliDriver#processLine(),
+ * a signal handler will invoke this registered callback if a Ctrl+C signal is detected while
+ * a command is being processed by the current thread.
+ */
+ def installSignalHandler() {
+ HiveInterruptUtils.add(new HiveInterruptCallback {
+ override def interrupt() {
+ // Handle remote execution mode
+ if (SparkSQLEnv.sparkContext != null) {
+ SparkSQLEnv.sparkContext.cancelAllJobs()
+ } else {
+ if (transport != null) {
+ // Force closing of TCP connection upon session termination
+ transport.getSocket.close()
+ }
+ }
+ }
+ })
+ }
+
+ def main(args: Array[String]) {
+ val oproc = new OptionsProcessor()
+ if (!oproc.process_stage1(args)) {
+ System.exit(1)
+ }
+
+ // NOTE: It is critical to do this here so that log4j is reinitialized
+ // before any of the other core hive classes are loaded
+ var logInitFailed = false
+ var logInitDetailMessage: String = null
+ try {
+ logInitDetailMessage = LogUtils.initHiveLog4j()
+ } catch {
+ case e: LogInitializationException =>
+ logInitFailed = true
+ logInitDetailMessage = e.getMessage
+ }
+
+ val sessionState = new CliSessionState(new HiveConf(classOf[SessionState]))
+
+ sessionState.in = System.in
+ try {
+ sessionState.out = new PrintStream(System.out, true, "UTF-8")
+ sessionState.info = new PrintStream(System.err, true, "UTF-8")
+ sessionState.err = new PrintStream(System.err, true, "UTF-8")
+ } catch {
+ case e: UnsupportedEncodingException => System.exit(3)
+ }
+
+ if (!oproc.process_stage2(sessionState)) {
+ System.exit(2)
+ }
+
+ if (!sessionState.getIsSilent) {
+ if (logInitFailed) System.err.println(logInitDetailMessage)
+ else SessionState.getConsole.printInfo(logInitDetailMessage)
+ }
+
+ // Set all properties specified via command line.
+ val conf: HiveConf = sessionState.getConf
+ sessionState.cmdProperties.entrySet().foreach { item: java.util.Map.Entry[Object, Object] =>
+ conf.set(item.getKey.asInstanceOf[String], item.getValue.asInstanceOf[String])
+ sessionState.getOverriddenConfigurations.put(
+ item.getKey.asInstanceOf[String], item.getValue.asInstanceOf[String])
+ }
+
+ SessionState.start(sessionState)
+
+ // Clean up after we exit
+ Runtime.getRuntime.addShutdownHook(
+ new Thread() {
+ override def run() {
+ SparkSQLEnv.stop()
+ }
+ }
+ )
+
+ // "-h" option has been passed, so connect to Hive thrift server.
+ if (sessionState.getHost != null) {
+ sessionState.connect()
+ if (sessionState.isRemoteMode) {
+ prompt = s"[${sessionState.getHost}:${sessionState.getPort}]" + prompt
+ continuedPrompt = "".padTo(prompt.length, ' ')
+ }
+ }
+
+ if (!sessionState.isRemoteMode && !ShimLoader.getHadoopShims.usesJobShell()) {
+ // Hadoop-20 and above - we need to augment classpath using hiveconf
+ // components.
+ // See also: code in ExecDriver.java
+ var loader = conf.getClassLoader
+ val auxJars = HiveConf.getVar(conf, HiveConf.ConfVars.HIVEAUXJARS)
+ if (StringUtils.isNotBlank(auxJars)) {
+ loader = Utilities.addToClassPath(loader, StringUtils.split(auxJars, ","))
+ }
+ conf.setClassLoader(loader)
+ Thread.currentThread().setContextClassLoader(loader)
+ }
+
+ val cli = new SparkSQLCLIDriver
+ cli.setHiveVariables(oproc.getHiveVariables)
+
+    // TODO: work around to redirect the log output to the console, because HiveContext
+    // will set the output into an invalid buffer.
+ sessionState.in = System.in
+ try {
+ sessionState.out = new PrintStream(System.out, true, "UTF-8")
+ sessionState.info = new PrintStream(System.err, true, "UTF-8")
+ sessionState.err = new PrintStream(System.err, true, "UTF-8")
+ } catch {
+ case e: UnsupportedEncodingException => System.exit(3)
+ }
+
+ // Execute -i init files (always in silent mode)
+ cli.processInitFiles(sessionState)
+
+ if (sessionState.execString != null) {
+ System.exit(cli.processLine(sessionState.execString))
+ }
+
+ try {
+ if (sessionState.fileName != null) {
+ System.exit(cli.processFile(sessionState.fileName))
+ }
+ } catch {
+ case e: FileNotFoundException =>
+ System.err.println(s"Could not open input file for reading. (${e.getMessage})")
+ System.exit(3)
+ }
+
+ val reader = new ConsoleReader()
+ reader.setBellEnabled(false)
+ // reader.setDebug(new PrintWriter(new FileWriter("writer.debug", true)))
+ CliDriver.getCommandCompletor.foreach((e) => reader.addCompletor(e))
+
+ val historyDirectory = System.getProperty("user.home")
+
+ try {
+ if (new File(historyDirectory).exists()) {
+ val historyFile = historyDirectory + File.separator + ".hivehistory"
+ reader.setHistory(new History(new File(historyFile)))
+ } else {
+ System.err.println("WARNING: Directory for Hive history file: " + historyDirectory +
+ " does not exist. History will not be available during this session.")
+ }
+ } catch {
+ case e: Exception =>
+ System.err.println("WARNING: Encountered an error while trying to initialize Hive's " +
+ "history file. History will not be available during this session.")
+ System.err.println(e.getMessage)
+ }
+
+ val clientTransportTSocketField = classOf[CliSessionState].getDeclaredField("transport")
+ clientTransportTSocketField.setAccessible(true)
+
+ transport = clientTransportTSocketField.get(sessionState).asInstanceOf[TSocket]
+
+ var ret = 0
+ var prefix = ""
+ val currentDB = ReflectionUtils.invokeStatic(classOf[CliDriver], "getFormattedDb",
+ classOf[HiveConf] -> conf, classOf[CliSessionState] -> sessionState)
+
+ def promptWithCurrentDB = s"$prompt$currentDB"
+ def continuedPromptWithDBSpaces = continuedPrompt + ReflectionUtils.invokeStatic(
+ classOf[CliDriver], "spacesForString", classOf[String] -> currentDB)
+
+ var currentPrompt = promptWithCurrentDB
+ var line = reader.readLine(currentPrompt + "> ")
+
+ while (line != null) {
+ if (prefix.nonEmpty) {
+ prefix += '\n'
+ }
+
+ if (line.trim().endsWith(";") && !line.trim().endsWith("\\;")) {
+ line = prefix + line
+ ret = cli.processLine(line, true)
+ prefix = ""
+ currentPrompt = promptWithCurrentDB
+ } else {
+ prefix = prefix + line
+ currentPrompt = continuedPromptWithDBSpaces
+ }
+
+ line = reader.readLine(currentPrompt + "> ")
+ }
+
+ sessionState.close()
+
+ System.exit(ret)
+ }
+}
+
+private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
+ private val sessionState = SessionState.get().asInstanceOf[CliSessionState]
+
+ private val LOG = LogFactory.getLog("CliDriver")
+
+ private val console = new SessionState.LogHelper(LOG)
+
+ private val conf: Configuration =
+ if (sessionState != null) sessionState.getConf else new Configuration()
+
+  // Force initializing SparkSQLEnv. This is put here instead of in object SparkSQLCLIDriver
+  // because the Hive unit tests do not go through the main() code path.
+ if (!sessionState.isRemoteMode) {
+ SparkSQLEnv.init()
+ }
+
+ override def processCmd(cmd: String): Int = {
+ val cmd_trimmed: String = cmd.trim()
+ val tokens: Array[String] = cmd_trimmed.split("\\s+")
+ val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim()
+ if (cmd_trimmed.toLowerCase.equals("quit") ||
+ cmd_trimmed.toLowerCase.equals("exit") ||
+ tokens(0).equalsIgnoreCase("source") ||
+ cmd_trimmed.startsWith("!") ||
+ tokens(0).toLowerCase.equals("list") ||
+ sessionState.isRemoteMode) {
+ val start = System.currentTimeMillis()
+ super.processCmd(cmd)
+ val end = System.currentTimeMillis()
+ val timeTaken: Double = (end - start) / 1000.0
+ console.printInfo(s"Time taken: $timeTaken seconds")
+ 0
+ } else {
+ var ret = 0
+ val hconf = conf.asInstanceOf[HiveConf]
+ val proc: CommandProcessor = CommandProcessorFactory.get(tokens(0), hconf)
+
+ if (proc != null) {
+ if (proc.isInstanceOf[Driver]) {
+ val driver = new SparkSQLDriver
+
+ driver.init()
+ val out = sessionState.out
+ val start:Long = System.currentTimeMillis()
+ if (sessionState.getIsVerbose) {
+ out.println(cmd)
+ }
+
+ ret = driver.run(cmd).getResponseCode
+ if (ret != 0) {
+ driver.close()
+ return ret
+ }
+
+ val res = new JArrayList[String]()
+
+ if (HiveConf.getBoolVar(conf, HiveConf.ConfVars.HIVE_CLI_PRINT_HEADER)) {
+ // Print the column names.
+ Option(driver.getSchema.getFieldSchemas).map { fields =>
+ out.println(fields.map(_.getName).mkString("\t"))
+ }
+ }
+
+ try {
+ while (!out.checkError() && driver.getResults(res)) {
+ res.foreach(out.println)
+ res.clear()
+ }
+ } catch {
+ case e:IOException =>
+ console.printError(
+ s"""Failed with exception ${e.getClass.getName}: ${e.getMessage}
+ |${org.apache.hadoop.util.StringUtils.stringifyException(e)}
+ """.stripMargin)
+ ret = 1
+ }
+
+ val cret = driver.close()
+ if (ret == 0) {
+ ret = cret
+ }
+
+ val end = System.currentTimeMillis()
+ if (end > start) {
+ val timeTaken:Double = (end - start) / 1000.0
+ console.printInfo(s"Time taken: $timeTaken seconds", null)
+ }
+
+ // Destroy the driver to release all the locks.
+ driver.destroy()
+ } else {
+ if (sessionState.getIsVerbose) {
+ sessionState.out.println(tokens(0) + " " + cmd_1)
+ }
+ ret = proc.run(cmd_1).getResponseCode
+ }
+ }
+ ret
+ }
+ }
+}
+
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala
new file mode 100644
index 0000000000..42cbf363b2
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import scala.collection.JavaConversions._
+
+import java.io.IOException
+import java.util.{List => JList}
+import javax.security.auth.login.LoginException
+
+import org.apache.commons.logging.Log
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.shims.ShimLoader
+import org.apache.hive.service.Service.STATE
+import org.apache.hive.service.auth.HiveAuthFactory
+import org.apache.hive.service.cli.CLIService
+import org.apache.hive.service.{AbstractService, Service, ServiceException}
+
+import org.apache.spark.sql.hive.HiveContext
+import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
+
+private[hive] class SparkSQLCLIService(hiveContext: HiveContext)
+ extends CLIService
+ with ReflectedCompositeService {
+
+ override def init(hiveConf: HiveConf) {
+ setSuperField(this, "hiveConf", hiveConf)
+
+ val sparkSqlSessionManager = new SparkSQLSessionManager(hiveContext)
+ setSuperField(this, "sessionManager", sparkSqlSessionManager)
+ addService(sparkSqlSessionManager)
+
+ try {
+ HiveAuthFactory.loginFromKeytab(hiveConf)
+ val serverUserName = ShimLoader.getHadoopShims
+ .getShortUserName(ShimLoader.getHadoopShims.getUGIForConf(hiveConf))
+ setSuperField(this, "serverUserName", serverUserName)
+ } catch {
+ case e @ (_: IOException | _: LoginException) =>
+ throw new ServiceException("Unable to login to kerberos with given principal/keytab", e)
+ }
+
+ initCompositeService(hiveConf)
+ }
+}
+
+private[thriftserver] trait ReflectedCompositeService { this: AbstractService =>
+ def initCompositeService(hiveConf: HiveConf) {
+ // Emulating `CompositeService.init(hiveConf)`
+ val serviceList = getAncestorField[JList[Service]](this, 2, "serviceList")
+ serviceList.foreach(_.init(hiveConf))
+
+ // Emulating `AbstractService.init(hiveConf)`
+ invoke(classOf[AbstractService], this, "ensureCurrentState", classOf[STATE] -> STATE.NOTINITED)
+ setAncestorField(this, 3, "hiveConf", hiveConf)
+ invoke(classOf[AbstractService], this, "changeState", classOf[STATE] -> STATE.INITED)
+ getAncestorField[Log](this, 3, "LOG").info(s"Service: $getName is inited.")
+ }
+}
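The hard-coded ancestor levels in `ReflectedCompositeService` only make sense relative to Hive's class hierarchy: counting superclasses from the concrete service, level 2 should land on `CompositeService` (which owns `serviceList`) and level 3 on `AbstractService` (which owns `hiveConf` and `LOG`). A self-contained sketch of that walk, using hypothetical stand-in classes rather than the Hive types:

```scala
object AncestorLevelSketch {
  // Hypothetical stand-ins mirroring the assumed chain:
  //   SparkSQLCLIService -> CLIService -> CompositeService -> AbstractService
  class AbstractServiceLike
  class CompositeServiceLike extends AbstractServiceLike
  class CLIServiceLike extends CompositeServiceLike
  class SparkServiceLike extends CLIServiceLike

  // Same traversal as ReflectionUtils.setAncestorField/getAncestorField.
  def ancestorAt(start: Class[_], level: Int): Class[_] =
    Iterator.iterate[Class[_]](start)(_.getSuperclass).drop(level).next()

  def main(args: Array[String]): Unit = {
    val start = classOf[SparkServiceLike]
    println(ancestorAt(start, 2).getSimpleName) // prints "CompositeServiceLike"
    println(ancestorAt(start, 3).getSimpleName) // prints "AbstractServiceLike"
  }
}
```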
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala
new file mode 100644
index 0000000000..5202aa9903
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import scala.collection.JavaConversions._
+
+import java.util.{ArrayList => JArrayList}
+
+import org.apache.commons.lang.exception.ExceptionUtils
+import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema}
+import org.apache.hadoop.hive.ql.Driver
+import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse
+
+import org.apache.spark.sql.Logging
+import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes}
+
+private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveContext)
+ extends Driver with Logging {
+
+ private var tableSchema: Schema = _
+ private var hiveResponse: Seq[String] = _
+
+ override def init(): Unit = {
+ }
+
+ private def getResultSetSchema(query: context.QueryExecution): Schema = {
+ val analyzed = query.analyzed
+ logger.debug(s"Result Schema: ${analyzed.output}")
+ if (analyzed.output.size == 0) {
+ new Schema(new FieldSchema("Response code", "string", "") :: Nil, null)
+ } else {
+ val fieldSchemas = analyzed.output.map { attr =>
+ new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "")
+ }
+
+ new Schema(fieldSchemas, null)
+ }
+ }
+
+ override def run(command: String): CommandProcessorResponse = {
+ val execution = context.executePlan(context.hql(command).logicalPlan)
+
+ // TODO unify the error code
+ try {
+ hiveResponse = execution.stringResult()
+ tableSchema = getResultSetSchema(execution)
+ new CommandProcessorResponse(0)
+ } catch {
+ case cause: Throwable =>
+ logger.error(s"Failed in [$command]", cause)
+ new CommandProcessorResponse(-3, ExceptionUtils.getFullStackTrace(cause), null)
+ }
+ }
+
+ override def close(): Int = {
+ hiveResponse = null
+ tableSchema = null
+ 0
+ }
+
+ override def getSchema: Schema = tableSchema
+
+ override def getResults(res: JArrayList[String]): Boolean = {
+ if (hiveResponse == null) {
+ false
+ } else {
+ res.addAll(hiveResponse)
+ hiveResponse = null
+ true
+ }
+ }
+
+ override def destroy() {
+ super.destroy()
+ hiveResponse = null
+ tableSchema = null
+ }
+}
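`SparkSQLCLIDriver.processCmd` drives this class through Hive's `Driver` interface. A condensed sketch of that interaction is below; it has to live in the `thriftserver` package because these classes are `private[hive]`, and the table name is hypothetical.

```scala
package org.apache.spark.sql.hive.thriftserver

import scala.collection.JavaConversions._

import java.util.{ArrayList => JArrayList}

object SparkSQLDriverExample {
  def main(args: Array[String]): Unit = {
    SparkSQLEnv.init() // starts the SparkContext and HiveContext

    val driver = new SparkSQLDriver
    driver.init()

    // `hive_test1` is a hypothetical table (the name is borrowed from CliSuite).
    val response = driver.run("SELECT key, val FROM hive_test1")
    if (response.getResponseCode == 0) {
      val results = new JArrayList[String]()
      // getResults hands back the buffered, Hive-formatted rows exactly once.
      while (driver.getResults(results)) {
        results.foreach(println)
        results.clear()
      }
    }

    driver.close()
    SparkSQLEnv.stop()
  }
}
```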
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
new file mode 100644
index 0000000000..451c3bd7b9
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import org.apache.hadoop.hive.ql.session.SessionState
+
+import org.apache.spark.scheduler.{SplitInfo, StatsReportListener}
+import org.apache.spark.sql.Logging
+import org.apache.spark.sql.hive.HiveContext
+import org.apache.spark.{SparkConf, SparkContext}
+
+/** A singleton object for the master program. The slaves should not access this. */
+private[hive] object SparkSQLEnv extends Logging {
+ logger.debug("Initializing SparkSQLEnv")
+
+ var hiveContext: HiveContext = _
+ var sparkContext: SparkContext = _
+
+ def init() {
+ if (hiveContext == null) {
+ sparkContext = new SparkContext(new SparkConf()
+ .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}"))
+
+ sparkContext.addSparkListener(new StatsReportListener())
+
+ hiveContext = new HiveContext(sparkContext) {
+ @transient override lazy val sessionState = SessionState.get()
+ @transient override lazy val hiveconf = sessionState.getConf
+ }
+ }
+ }
+
+ /** Cleans up and shuts down the Spark SQL environments. */
+ def stop() {
+ logger.debug("Shutting down Spark SQL Environment")
+ // Stop the SparkContext
+ if (SparkSQLEnv.sparkContext != null) {
+ sparkContext.stop()
+ sparkContext = null
+ hiveContext = null
+ }
+ }
+}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala
new file mode 100644
index 0000000000..6b3275b4ea
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import java.util.concurrent.Executors
+
+import org.apache.commons.logging.Log
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.conf.HiveConf.ConfVars
+import org.apache.hive.service.cli.session.SessionManager
+
+import org.apache.spark.sql.hive.HiveContext
+import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
+import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager
+
+private[hive] class SparkSQLSessionManager(hiveContext: HiveContext)
+ extends SessionManager
+ with ReflectedCompositeService {
+
+ override def init(hiveConf: HiveConf) {
+ setSuperField(this, "hiveConf", hiveConf)
+
+ val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS)
+ setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize))
+ getAncestorField[Log](this, 3, "LOG").info(
+ s"HiveServer2: Async execution pool size $backgroundPoolSize")
+
+ val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext)
+ setSuperField(this, "operationManager", sparkSqlOperationManager)
+ addService(sparkSqlOperationManager)
+
+ initCompositeService(hiveConf)
+ }
+}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
new file mode 100644
index 0000000000..a4e1f3e762
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver.server
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.ArrayBuffer
+import scala.math.{random, round}
+
+import java.sql.Timestamp
+import java.util.{Map => JMap}
+
+import org.apache.hadoop.hive.common.`type`.HiveDecimal
+import org.apache.hadoop.hive.metastore.api.FieldSchema
+import org.apache.hive.service.cli._
+import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager}
+import org.apache.hive.service.cli.session.HiveSession
+
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.hive.thriftserver.ReflectionUtils
+import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes}
+import org.apache.spark.sql.{Logging, SchemaRDD, Row => SparkRow}
+
+/**
+ * Executes queries using Spark SQL, and maintains a list of handles to active queries.
+ */
+class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManager with Logging {
+ val handleToOperation = ReflectionUtils
+ .getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation")
+
+ override def newExecuteStatementOperation(
+ parentSession: HiveSession,
+ statement: String,
+ confOverlay: JMap[String, String],
+ async: Boolean): ExecuteStatementOperation = synchronized {
+
+ val operation = new ExecuteStatementOperation(parentSession, statement, confOverlay) {
+ private var result: SchemaRDD = _
+ private var iter: Iterator[SparkRow] = _
+ private var dataTypes: Array[DataType] = _
+
+ def close(): Unit = {
+ // RDDs will be cleaned automatically upon garbage collection.
+ logger.debug("CLOSING")
+ }
+
+ def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = {
+ if (!iter.hasNext) {
+ new RowSet()
+ } else {
+ val maxRows = maxRowsL.toInt // Do you really want a row batch larger than Int Max? No.
+ var curRow = 0
+ var rowSet = new ArrayBuffer[Row](maxRows)
+
+ while (curRow < maxRows && iter.hasNext) {
+ val sparkRow = iter.next()
+ val row = new Row()
+ var curCol = 0
+
+ while (curCol < sparkRow.length) {
+ dataTypes(curCol) match {
+ case StringType =>
+ row.addString(sparkRow(curCol).asInstanceOf[String])
+ case IntegerType =>
+ row.addColumnValue(ColumnValue.intValue(sparkRow.getInt(curCol)))
+ case BooleanType =>
+ row.addColumnValue(ColumnValue.booleanValue(sparkRow.getBoolean(curCol)))
+ case DoubleType =>
+ row.addColumnValue(ColumnValue.doubleValue(sparkRow.getDouble(curCol)))
+ case FloatType =>
+ row.addColumnValue(ColumnValue.floatValue(sparkRow.getFloat(curCol)))
+ case DecimalType =>
+ val hiveDecimal = sparkRow.get(curCol).asInstanceOf[BigDecimal].bigDecimal
+ row.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal)))
+ case LongType =>
+ row.addColumnValue(ColumnValue.longValue(sparkRow.getLong(curCol)))
+ case ByteType =>
+ row.addColumnValue(ColumnValue.byteValue(sparkRow.getByte(curCol)))
+ case ShortType =>
+ row.addColumnValue(ColumnValue.intValue(sparkRow.getShort(curCol)))
+ case TimestampType =>
+ row.addColumnValue(
+ ColumnValue.timestampValue(sparkRow.get(curCol).asInstanceOf[Timestamp]))
+ case BinaryType | _: ArrayType | _: StructType | _: MapType =>
+ val hiveString = result
+ .queryExecution
+ .asInstanceOf[HiveContext#QueryExecution]
+ .toHiveString((sparkRow.get(curCol), dataTypes(curCol)))
+ row.addColumnValue(ColumnValue.stringValue(hiveString))
+ }
+ curCol += 1
+ }
+ rowSet += row
+ curRow += 1
+ }
+ new RowSet(rowSet, 0)
+ }
+ }
+
+ def getResultSetSchema: TableSchema = {
+ logger.warn(s"Result Schema: ${result.queryExecution.analyzed.output}")
+ if (result.queryExecution.analyzed.output.size == 0) {
+ new TableSchema(new FieldSchema("Result", "string", "") :: Nil)
+ } else {
+ val schema = result.queryExecution.analyzed.output.map { attr =>
+ new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "")
+ }
+ new TableSchema(schema)
+ }
+ }
+
+ def run(): Unit = {
+ logger.info(s"Running query '$statement'")
+ setState(OperationState.RUNNING)
+ try {
+ result = hiveContext.hql(statement)
+ logger.debug(result.queryExecution.toString())
+ val groupId = round(random * 1000000).toString
+ hiveContext.sparkContext.setJobGroup(groupId, statement)
+ iter = result.queryExecution.toRdd.toLocalIterator
+ dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray
+ setHasResultSet(true)
+ } catch {
+ // Actually do need to catch Throwable as some failures don't inherit from Exception and
+ // HiveServer will silently swallow them.
+ case e: Throwable =>
+ logger.error("Error executing query:",e)
+ throw new HiveSQLException(e.toString)
+ }
+ setState(OperationState.FINISHED)
+ }
+ }
+
+ handleToOperation.put(operation.getHandle, operation)
+ operation
+ }
+}
diff --git a/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt b/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt
new file mode 100644
index 0000000000..850f8014b6
--- /dev/null
+++ b/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt
@@ -0,0 +1,5 @@
+238val_238
+86val_86
+311val_311
+27val_27
+165val_165
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
new file mode 100644
index 0000000000..b90670a796
--- /dev/null
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import java.io.{BufferedReader, InputStreamReader, PrintWriter}
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.sql.hive.test.TestHive
+
+class CliSuite extends FunSuite with BeforeAndAfterAll with TestUtils {
+ val WAREHOUSE_PATH = TestUtils.getWarehousePath("cli")
+ val METASTORE_PATH = TestUtils.getMetastorePath("cli")
+
+ override def beforeAll() {
+ val pb = new ProcessBuilder(
+ "../../bin/spark-sql",
+ "--master",
+ "local",
+ "--hiveconf",
+ s"javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$METASTORE_PATH;create=true",
+ "--hiveconf",
+ "hive.metastore.warehouse.dir=" + WAREHOUSE_PATH)
+
+ process = pb.start()
+ outputWriter = new PrintWriter(process.getOutputStream, true)
+ inputReader = new BufferedReader(new InputStreamReader(process.getInputStream))
+ errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream))
+ waitForOutput(inputReader, "spark-sql>")
+ }
+
+ override def afterAll() {
+ process.destroy()
+ process.waitFor()
+ }
+
+ test("simple commands") {
+ val dataFilePath = getDataFile("data/files/small_kv.txt")
+ executeQuery("create table hive_test1(key int, val string);")
+ executeQuery("load data local inpath '" + dataFilePath+ "' overwrite into table hive_test1;")
+ executeQuery("cache table hive_test1", "Time taken")
+ }
+}
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
new file mode 100644
index 0000000000..59f4952b78
--- /dev/null
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import scala.collection.JavaConversions._
+import scala.concurrent.ExecutionContext.Implicits.global
+import scala.concurrent._
+
+import java.io.{BufferedReader, InputStreamReader}
+import java.sql.{Connection, DriverManager, Statement}
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.sql.Logging
+import org.apache.spark.sql.catalyst.util.getTempFilePath
+
+/**
+ * Test for the HiveThriftServer2 using JDBC.
+ */
+class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUtils with Logging {
+
+ val WAREHOUSE_PATH = getTempFilePath("warehouse")
+ val METASTORE_PATH = getTempFilePath("metastore")
+
+ val DRIVER_NAME = "org.apache.hive.jdbc.HiveDriver"
+ val TABLE = "test"
+  // Use a different port than the Hive standard 10000 for tests,
+  // to avoid issues with the port already being taken on some machines.
+ val PORT = "10000"
+
+ // If verbose is true, the test program will print all outputs coming from the Hive Thrift server.
+ val VERBOSE = Option(System.getenv("SPARK_SQL_TEST_VERBOSE")).getOrElse("false").toBoolean
+
+ Class.forName(DRIVER_NAME)
+
+ override def beforeAll() { launchServer() }
+
+ override def afterAll() { stopServer() }
+
+ private def launchServer(args: Seq[String] = Seq.empty) {
+ // Forking a new process to start the Hive Thrift server. The reason to do this is it is
+ // hard to clean up Hive resources entirely, so we just start a new process and kill
+ // that process for cleanup.
+ val defaultArgs = Seq(
+ "../../sbin/start-thriftserver.sh",
+ "--master local",
+ "--hiveconf",
+ "hive.root.logger=INFO,console",
+ "--hiveconf",
+ s"javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$METASTORE_PATH;create=true",
+ "--hiveconf",
+ s"hive.metastore.warehouse.dir=$WAREHOUSE_PATH")
+ val pb = new ProcessBuilder(defaultArgs ++ args)
+ process = pb.start()
+ inputReader = new BufferedReader(new InputStreamReader(process.getInputStream))
+ errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream))
+ waitForOutput(inputReader, "ThriftBinaryCLIService listening on")
+
+ // Spawn a thread to read the output from the forked process.
+ // Note that this is necessary since in some configurations, log4j could be blocked
+ // if its output to stderr are not read, and eventually blocking the entire test suite.
+ future {
+ while (true) {
+ val stdout = readFrom(inputReader)
+ val stderr = readFrom(errorReader)
+ if (VERBOSE && stdout.length > 0) {
+ println(stdout)
+ }
+ if (VERBOSE && stderr.length > 0) {
+ println(stderr)
+ }
+ Thread.sleep(50)
+ }
+ }
+ }
+
+ private def stopServer() {
+ process.destroy()
+ process.waitFor()
+ }
+
+ test("test query execution against a Hive Thrift server") {
+ Thread.sleep(5 * 1000)
+ val dataFilePath = getDataFile("data/files/small_kv.txt")
+ val stmt = createStatement()
+ stmt.execute("DROP TABLE IF EXISTS test")
+ stmt.execute("DROP TABLE IF EXISTS test_cached")
+ stmt.execute("CREATE TABLE test(key int, val string)")
+ stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test")
+ stmt.execute("CREATE TABLE test_cached as select * from test limit 4")
+ stmt.execute("CACHE TABLE test_cached")
+
+ var rs = stmt.executeQuery("select count(*) from test")
+ rs.next()
+ assert(rs.getInt(1) === 5)
+
+ rs = stmt.executeQuery("select count(*) from test_cached")
+ rs.next()
+ assert(rs.getInt(1) === 4)
+
+ stmt.close()
+ }
+
+ def getConnection: Connection = {
+ val connectURI = s"jdbc:hive2://localhost:$PORT/"
+ DriverManager.getConnection(connectURI, System.getProperty("user.name"), "")
+ }
+
+ def createStatement(): Statement = getConnection.createStatement()
+}
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala
new file mode 100644
index 0000000000..bb2242618f
--- /dev/null
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import java.io.{BufferedReader, PrintWriter}
+import java.text.SimpleDateFormat
+import java.util.Date
+
+import org.apache.hadoop.hive.common.LogUtils
+import org.apache.hadoop.hive.common.LogUtils.LogInitializationException
+
+object TestUtils {
+ val timestamp = new SimpleDateFormat("yyyyMMdd-HHmmss")
+
+ def getWarehousePath(prefix: String): String = {
+ System.getProperty("user.dir") + "/test_warehouses/" + prefix + "-warehouse-" +
+ timestamp.format(new Date)
+ }
+
+ def getMetastorePath(prefix: String): String = {
+ System.getProperty("user.dir") + "/test_warehouses/" + prefix + "-metastore-" +
+ timestamp.format(new Date)
+ }
+
+  // Dummy function to initialize the log4j properties.
+ def init() { }
+
+ // initialize log4j
+ try {
+ LogUtils.initHiveLog4j()
+ } catch {
+ case e: LogInitializationException => // Ignore the error.
+ }
+}
+
+trait TestUtils {
+ var process : Process = null
+ var outputWriter : PrintWriter = null
+ var inputReader : BufferedReader = null
+ var errorReader : BufferedReader = null
+
+ def executeQuery(
+ cmd: String, outputMessage: String = "OK", timeout: Long = 15000): String = {
+ println("Executing: " + cmd + ", expecting output: " + outputMessage)
+ outputWriter.write(cmd + "\n")
+ outputWriter.flush()
+ waitForQuery(timeout, outputMessage)
+ }
+
+ protected def waitForQuery(timeout: Long, message: String): String = {
+ if (waitForOutput(errorReader, message, timeout)) {
+ Thread.sleep(500)
+ readOutput()
+ } else {
+ assert(false, "Didn't find \"" + message + "\" in the output:\n" + readOutput())
+ null
+ }
+ }
+
+ // Wait for the specified str to appear in the output.
+ protected def waitForOutput(
+ reader: BufferedReader, str: String, timeout: Long = 10000): Boolean = {
+ val startTime = System.currentTimeMillis
+ var out = ""
+ while (!out.contains(str) && System.currentTimeMillis < (startTime + timeout)) {
+ out += readFrom(reader)
+ }
+ out.contains(str)
+ }
+
+ // Read stdout output and filter out garbage collection messages.
+ protected def readOutput(): String = {
+ val output = readFrom(inputReader)
+ // Remove GC Messages
+ val filteredOutput = output.lines.filterNot(x => x.contains("[GC") || x.contains("[Full GC"))
+ .mkString("\n")
+ filteredOutput
+ }
+
+ protected def readFrom(reader: BufferedReader): String = {
+ var out = ""
+ var c = 0
+ while (reader.ready) {
+ c = reader.read()
+ out += c.asInstanceOf[Char]
+ }
+ out
+ }
+
+ protected def getDataFile(name: String) = {
+ Thread.currentThread().getContextClassLoader.getResource(name)
+ }
+}