From 0526fea483066086dfc27d1606f74220fe822f7f Mon Sep 17 00:00:00 2001 From: Cheolsoo Park Date: Thu, 4 Jun 2015 13:27:35 -0700 Subject: [SPARK-6909][SQL] Remove Hive Shim code This is a follow-up on #6393. I am removing the following files in this PR. ``` ./sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala ./sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala ``` Basically, I re-factored the shim code as follows- * Rewrote code directly with Hive 0.13 methods, or * Converted code into private methods, or * Extracted code into separate classes But for leftover code that didn't fit in any of these cases, I created a HiveShim object. For eg, helper functions which wrap Hive 0.13 methods to work around Hive bugs are placed here. Author: Cheolsoo Park Closes #6604 from piaozhexiu/SPARK-6909 and squashes the following commits: 5dccc20 [Cheolsoo Park] Remove hive shim code --- .../hive/thriftserver/AbstractSparkSQLDriver.scala | 81 ---------- .../sql/hive/thriftserver/HiveThriftServer2.scala | 10 +- .../SparkExecuteStatementOperation.scala | 176 +++++++++++++++++++++ .../sql/hive/thriftserver/SparkSQLCLIDriver.scala | 6 +- .../sql/hive/thriftserver/SparkSQLCLIService.scala | 7 +- .../sql/hive/thriftserver/SparkSQLDriver.scala | 95 +++++++++++ .../spark/sql/hive/thriftserver/SparkSQLEnv.scala | 4 +- .../hive/thriftserver/SparkSQLSessionManager.scala | 75 +++++++++ .../thriftserver/HiveThriftServer2Suites.scala | 8 +- 9 files changed, 364 insertions(+), 98 deletions(-) delete mode 100644 sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala create mode 100644 sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala create mode 100644 sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala create mode 100644 sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala (limited to 'sql/hive-thriftserver/src') diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala deleted file mode 100644 index 48ac9062af..0000000000 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.thriftserver - -import scala.collection.JavaConversions._ - -import org.apache.commons.lang3.exception.ExceptionUtils -import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} -import org.apache.hadoop.hive.ql.Driver -import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse - -import org.apache.spark.Logging -import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} - -private[hive] abstract class AbstractSparkSQLDriver( - val context: HiveContext = SparkSQLEnv.hiveContext) extends Driver with Logging { - - private[hive] var tableSchema: Schema = _ - private[hive] var hiveResponse: Seq[String] = _ - - override def init(): Unit = { - } - - private def getResultSetSchema(query: context.QueryExecution): Schema = { - val analyzed = query.analyzed - logDebug(s"Result Schema: ${analyzed.output}") - if (analyzed.output.size == 0) { - new Schema(new FieldSchema("Response code", "string", "") :: Nil, null) - } else { - val fieldSchemas = analyzed.output.map { attr => - new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") - } - - new Schema(fieldSchemas, null) - } - } - - override def run(command: String): CommandProcessorResponse = { - // TODO unify the error code - try { - context.sparkContext.setJobDescription(command) - val execution = context.executePlan(context.sql(command).logicalPlan) - hiveResponse = execution.stringResult() - tableSchema = getResultSetSchema(execution) - new CommandProcessorResponse(0) - } catch { - case cause: Throwable => - logError(s"Failed in [$command]", cause) - new CommandProcessorResponse(1, ExceptionUtils.getStackTrace(cause), null) - } - } - - override def close(): Int = { - hiveResponse = null - tableSchema = null - 0 - } - - override def getSchema: Schema = tableSchema - - override def destroy() { - super.destroy() - hiveResponse = null - tableSchema = null - } -} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 94687eeda4..5b391d3dce 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -17,9 +17,6 @@ package org.apache.spark.sql.hive.thriftserver -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer - import org.apache.commons.logging.LogFactory import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars @@ -29,12 +26,15 @@ import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd, SparkListenerJobStart} import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab -import org.apache.spark.sql.hive.{HiveContext, HiveShim} import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkContext} +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + /** * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a * `HiveThriftServer2` thrift server. @@ -51,7 +51,7 @@ object HiveThriftServer2 extends Logging { @DeveloperApi def startWithContext(sqlContext: HiveContext): Unit = { val server = new HiveThriftServer2(sqlContext) - sqlContext.setConf("spark.sql.hive.version", HiveShim.version) + sqlContext.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion) server.init(sqlContext.hiveconf) server.start() listener = new HiveThriftServer2Listener(server, sqlContext.conf) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala new file mode 100644 index 0000000000..c0d1266212 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.sql.{Date, Timestamp} +import java.util.{Map => JMap, UUID} + +import org.apache.hadoop.hive.metastore.api.FieldSchema +import org.apache.hive.service.cli._ +import org.apache.hive.service.cli.operation.ExecuteStatementOperation +import org.apache.hive.service.cli.session.HiveSession + +import org.apache.spark.Logging +import org.apache.spark.sql.execution.SetCommand +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf} + +import scala.collection.JavaConversions._ +import scala.collection.mutable.{ArrayBuffer, Map => SMap} + +private[hive] class SparkExecuteStatementOperation( + parentSession: HiveSession, + statement: String, + confOverlay: JMap[String, String], + runInBackground: Boolean = true) + (hiveContext: HiveContext, sessionToActivePool: SMap[SessionHandle, String]) + // NOTE: `runInBackground` is set to `false` intentionally to disable asynchronous execution + extends ExecuteStatementOperation(parentSession, statement, confOverlay, false) + with Logging { + + private var result: DataFrame = _ + private var iter: Iterator[SparkRow] = _ + private var dataTypes: Array[DataType] = _ + + def close(): Unit = { + // RDDs will be cleaned automatically upon garbage collection. + logDebug("CLOSING") + } + + def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int) { + dataTypes(ordinal) match { + case StringType => + to += from.getString(ordinal) + case IntegerType => + to += from.getInt(ordinal) + case BooleanType => + to += from.getBoolean(ordinal) + case DoubleType => + to += from.getDouble(ordinal) + case FloatType => + to += from.getFloat(ordinal) + case DecimalType() => + to += from.getDecimal(ordinal) + case LongType => + to += from.getLong(ordinal) + case ByteType => + to += from.getByte(ordinal) + case ShortType => + to += from.getShort(ordinal) + case DateType => + to += from.getAs[Date](ordinal) + case TimestampType => + to += from.getAs[Timestamp](ordinal) + case BinaryType | _: ArrayType | _: StructType | _: MapType => + val hiveString = HiveContext.toHiveString((from.get(ordinal), dataTypes(ordinal))) + to += hiveString + } + } + + def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = { + validateDefaultFetchOrientation(order) + assertState(OperationState.FINISHED) + setHasResultSet(true) + val resultRowSet: RowSet = RowSetFactory.create(getResultSetSchema, getProtocolVersion) + if (!iter.hasNext) { + resultRowSet + } else { + // maxRowsL here typically maps to java.sql.Statement.getFetchSize, which is an int + val maxRows = maxRowsL.toInt + var curRow = 0 + while (curRow < maxRows && iter.hasNext) { + val sparkRow = iter.next() + val row = ArrayBuffer[Any]() + var curCol = 0 + while (curCol < sparkRow.length) { + if (sparkRow.isNullAt(curCol)) { + row += null + } else { + addNonNullColumnValue(sparkRow, row, curCol) + } + curCol += 1 + } + resultRowSet.addRow(row.toArray.asInstanceOf[Array[Object]]) + curRow += 1 + } + resultRowSet + } + } + + def getResultSetSchema: TableSchema = { + logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") + if (result.queryExecution.analyzed.output.size == 0) { + new TableSchema(new FieldSchema("Result", "string", "") :: Nil) + } else { + val schema = result.queryExecution.analyzed.output.map { attr => + new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") + } + new TableSchema(schema) + } + } + + def run(): Unit = { + val statementId = UUID.randomUUID().toString + logInfo(s"Running query '$statement'") + setState(OperationState.RUNNING) + HiveThriftServer2.listener.onStatementStart( + statementId, + parentSession.getSessionHandle.getSessionId.toString, + statement, + statementId, + parentSession.getUsername) + hiveContext.sparkContext.setJobGroup(statementId, statement) + sessionToActivePool.get(parentSession.getSessionHandle).foreach { pool => + hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) + } + try { + result = hiveContext.sql(statement) + logDebug(result.queryExecution.toString()) + result.queryExecution.logical match { + case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value))), _) => + sessionToActivePool(parentSession.getSessionHandle) = value + logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") + case _ => + } + HiveThriftServer2.listener.onStatementParsed(statementId, result.queryExecution.toString()) + iter = { + val useIncrementalCollect = + hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean + if (useIncrementalCollect) { + result.rdd.toLocalIterator + } else { + result.collect().iterator + } + } + dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray + setHasResultSet(true) + } catch { + // Actually do need to catch Throwable as some failures don't inherit from Exception and + // HiveServer will silently swallow them. + case e: Throwable => + setState(OperationState.ERROR) + HiveThriftServer2.listener.onStatementError( + statementId, e.getMessage, e.getStackTraceString) + logError("Error executing query:", e) + throw new HiveSQLException(e.toString) + } + setState(OperationState.FINISHED) + HiveThriftServer2.listener.onStatementFinish(statementId) + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 14f6f658d9..039cfa40d2 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -32,12 +32,12 @@ import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities -import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, CommandProcessor} +import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, CommandProcessor, CommandProcessorFactory} import org.apache.hadoop.hive.ql.session.SessionState import org.apache.thrift.transport.TSocket import org.apache.spark.Logging -import org.apache.spark.sql.hive.{HiveContext, HiveShim} +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.util.Utils private[hive] object SparkSQLCLIDriver { @@ -267,7 +267,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { } else { var ret = 0 val hconf = conf.asInstanceOf[HiveConf] - val proc: CommandProcessor = HiveShim.getCommandProcessor(Array(tokens(0)), hconf) + val proc: CommandProcessor = CommandProcessorFactory.get(Array(tokens(0)), hconf) if (proc != null) { if (proc.isInstanceOf[Driver] || proc.isInstanceOf[SetProcessor] || diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index 499e077d72..41f647d5f8 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -21,8 +21,6 @@ import java.io.IOException import java.util.{List => JList} import javax.security.auth.login.LoginException -import scala.collection.JavaConversions._ - import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.shims.ShimLoader @@ -34,7 +32,8 @@ import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ -import org.apache.spark.util.Utils + +import scala.collection.JavaConversions._ private[hive] class SparkSQLCLIService(hiveContext: HiveContext) extends CLIService @@ -52,7 +51,7 @@ private[hive] class SparkSQLCLIService(hiveContext: HiveContext) try { HiveAuthFactory.loginFromKeytab(hiveConf) sparkServiceUGI = ShimLoader.getHadoopShims.getUGIForConf(hiveConf) - HiveThriftServerShim.setServerUserName(sparkServiceUGI, this) + setSuperField(this, "serviceUGI", sparkServiceUGI) } catch { case e @ (_: IOException | _: LoginException) => throw new ServiceException("Unable to login to kerberos with given principal/keytab", e) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala new file mode 100644 index 0000000000..77272aecf2 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.util.{ArrayList => JArrayList, List => JList} + +import org.apache.commons.lang3.exception.ExceptionUtils +import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} +import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse + +import org.apache.spark.Logging +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} + +import scala.collection.JavaConversions._ + +private[hive] class SparkSQLDriver( + val context: HiveContext = SparkSQLEnv.hiveContext) + extends Driver + with Logging { + + private[hive] var tableSchema: Schema = _ + private[hive] var hiveResponse: Seq[String] = _ + + override def init(): Unit = { + } + + private def getResultSetSchema(query: context.QueryExecution): Schema = { + val analyzed = query.analyzed + logDebug(s"Result Schema: ${analyzed.output}") + if (analyzed.output.size == 0) { + new Schema(new FieldSchema("Response code", "string", "") :: Nil, null) + } else { + val fieldSchemas = analyzed.output.map { attr => + new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") + } + + new Schema(fieldSchemas, null) + } + } + + override def run(command: String): CommandProcessorResponse = { + // TODO unify the error code + try { + context.sparkContext.setJobDescription(command) + val execution = context.executePlan(context.sql(command).logicalPlan) + hiveResponse = execution.stringResult() + tableSchema = getResultSetSchema(execution) + new CommandProcessorResponse(0) + } catch { + case cause: Throwable => + logError(s"Failed in [$command]", cause) + new CommandProcessorResponse(1, ExceptionUtils.getStackTrace(cause), null) + } + } + + override def close(): Int = { + hiveResponse = null + tableSchema = null + 0 + } + + override def getResults(res: JList[_]): Boolean = { + if (hiveResponse == null) { + false + } else { + res.asInstanceOf[JArrayList[String]].addAll(hiveResponse) + hiveResponse = null + true + } + } + + override def getSchema: Schema = tableSchema + + override def destroy() { + super.destroy() + hiveResponse = null + tableSchema = null + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 7c0c505e2d..79eda1f512 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -22,7 +22,7 @@ import java.io.PrintStream import scala.collection.JavaConversions._ import org.apache.spark.scheduler.StatsReportListener -import org.apache.spark.sql.hive.{HiveShim, HiveContext} +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.util.Utils @@ -56,7 +56,7 @@ private[hive] object SparkSQLEnv extends Logging { hiveContext.metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) hiveContext.metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) - hiveContext.setConf("spark.sql.hive.version", HiveShim.version) + hiveContext.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion) if (log.isDebugEnabled) { hiveContext.hiveconf.getAllProperties.toSeq.sorted.foreach { case (k, v) => diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala new file mode 100644 index 0000000000..357b27f740 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.util.concurrent.Executors + +import org.apache.commons.logging.Log +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hive.service.cli.SessionHandle +import org.apache.hive.service.cli.session.SessionManager +import org.apache.hive.service.cli.thrift.TProtocolVersion + +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ +import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager + +private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) + extends SessionManager + with ReflectedCompositeService { + + private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) + + override def init(hiveConf: HiveConf) { + setSuperField(this, "hiveConf", hiveConf) + + val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) + setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) + getAncestorField[Log](this, 3, "LOG").info( + s"HiveServer2: Async execution pool size $backgroundPoolSize") + + setSuperField(this, "operationManager", sparkSqlOperationManager) + addService(sparkSqlOperationManager) + + initCompositeService(hiveConf) + } + + override def openSession(protocol: TProtocolVersion, + username: String, + passwd: String, + sessionConf: java.util.Map[String, String], + withImpersonation: Boolean, + delegationToken: String): SessionHandle = { + hiveContext.openSession() + val sessionHandle = super.openSession( + protocol, username, passwd, sessionConf, withImpersonation, delegationToken) + val session = super.getSession(sessionHandle) + HiveThriftServer2.listener.onSessionCreated( + session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername) + sessionHandle + } + + override def closeSession(sessionHandle: SessionHandle) { + HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString) + super.closeSession(sessionHandle) + sparkSqlOperationManager.sessionToActivePool -= sessionHandle + + hiveContext.detachSession() + } +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index a93a3dee43..f57c7083ea 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -40,7 +40,7 @@ import org.apache.thrift.transport.TSocket import org.scalatest.BeforeAndAfterAll import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.sql.hive.HiveShim +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.util.Utils object TestData { @@ -111,7 +111,8 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { withJdbcStatement { statement => val resultSet = statement.executeQuery("SET spark.sql.hive.version") resultSet.next() - assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") + assert(resultSet.getString(1) === + s"spark.sql.hive.version=${HiveContext.hiveExecutionVersion}") } } @@ -365,7 +366,8 @@ class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { withJdbcStatement { statement => val resultSet = statement.executeQuery("SET spark.sql.hive.version") resultSet.next() - assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") + assert(resultSet.getString(1) === + s"spark.sql.hive.version=${HiveContext.hiveExecutionVersion}") } } } -- cgit v1.2.3