aboutsummaryrefslogtreecommitdiff
path: root/sql/hive-thriftserver/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'sql/hive-thriftserver/src/test')
-rw-r--r--sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala144
1 files changed, 110 insertions, 34 deletions
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
index c60e8fa5b1..65d910a0c3 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
@@ -30,42 +30,95 @@ import scala.util.Try
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hive.jdbc.HiveDriver
+import org.apache.hive.service.auth.PlainSaslHelper
+import org.apache.hive.service.cli.GetInfoType
+import org.apache.hive.service.cli.thrift.TCLIService.Client
+import org.apache.hive.service.cli.thrift._
+import org.apache.thrift.protocol.TBinaryProtocol
+import org.apache.thrift.transport.TSocket
import org.scalatest.FunSuite
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.util.getTempFilePath
+import org.apache.spark.sql.hive.HiveShim
/**
* Tests for the HiveThriftServer2 using JDBC.
+ *
+ * NOTE: SPARK_PREPEND_CLASSES is explicitly disabled in this test suite. Assembly jar must be
+ * rebuilt after changing HiveThriftServer2 related code.
*/
class HiveThriftServer2Suite extends FunSuite with Logging {
Class.forName(classOf[HiveDriver].getCanonicalName)
- def startThriftServerWithin(timeout: FiniteDuration = 1.minute)(f: Statement => Unit) {
+ def randomListeningPort = {
+ // Let the system choose a random available port to avoid collisions with other parallel
+ // builds.
+ val socket = new ServerSocket(0)
+ val port = socket.getLocalPort
+ socket.close()
+ port
+ }
+
+ def withJdbcStatement(serverStartTimeout: FiniteDuration = 1.minute)(f: Statement => Unit) {
+ val port = randomListeningPort
+
+ startThriftServer(port, serverStartTimeout) {
+ val jdbcUri = s"jdbc:hive2://${"localhost"}:$port/"
+ val user = System.getProperty("user.name")
+ val connection = DriverManager.getConnection(jdbcUri, user, "")
+ val statement = connection.createStatement()
+
+ try {
+ f(statement)
+ } finally {
+ statement.close()
+ connection.close()
+ }
+ }
+ }
+
+ def withCLIServiceClient(
+ serverStartTimeout: FiniteDuration = 1.minute)(
+ f: ThriftCLIServiceClient => Unit) {
+ val port = randomListeningPort
+
+ startThriftServer(port) {
+ // Transport creation logic below mimics HiveConnection.createBinaryTransport
+ val rawTransport = new TSocket("localhost", port)
+ val user = System.getProperty("user.name")
+ val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport)
+ val protocol = new TBinaryProtocol(transport)
+ val client = new ThriftCLIServiceClient(new Client(protocol))
+
+ transport.open()
+
+ try {
+ f(client)
+ } finally {
+ transport.close()
+ }
+ }
+ }
+
+ def startThriftServer(
+ port: Int,
+ serverStartTimeout: FiniteDuration = 1.minute)(
+ f: => Unit) {
val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator)
val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator)
val warehousePath = getTempFilePath("warehouse")
val metastorePath = getTempFilePath("metastore")
val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true"
- val listeningHost = "localhost"
- val listeningPort = {
- // Let the system to choose a random available port to avoid collision with other parallel
- // builds.
- val socket = new ServerSocket(0)
- val port = socket.getLocalPort
- socket.close()
- port
- }
-
val command =
s"""$startScript
| --master local
| --hiveconf hive.root.logger=INFO,console
| --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri
| --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath
- | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=$listeningHost
- | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$listeningPort
+ | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=${"localhost"}
+ | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$port
""".stripMargin.split("\\s+").toSeq
val serverRunning = Promise[Unit]()
@@ -92,31 +145,25 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
}
}
- // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths
- Process(command, None, "SPARK_TESTING" -> "0").run(ProcessLogger(
+ val env = Seq(
+ // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths
+ "SPARK_TESTING" -> "0",
+ // Prevents loading classes out of the assembly jar. Otherwise Utils.sparkVersion can't read
+ // proper version information from the jar manifest.
+ "SPARK_PREPEND_CLASSES" -> "")
+
+ Process(command, None, env: _*).run(ProcessLogger(
captureThriftServerOutput("stdout"),
captureThriftServerOutput("stderr")))
- val jdbcUri = s"jdbc:hive2://$listeningHost:$listeningPort/"
- val user = System.getProperty("user.name")
-
try {
- Await.result(serverRunning.future, timeout)
-
- val connection = DriverManager.getConnection(jdbcUri, user, "")
- val statement = connection.createStatement()
-
- try {
- f(statement)
- } finally {
- statement.close()
- connection.close()
- }
+ Await.result(serverRunning.future, serverStartTimeout)
+ f
} catch {
case cause: Exception =>
cause match {
case _: TimeoutException =>
- logError(s"Failed to start Hive Thrift server within $timeout", cause)
+ logError(s"Failed to start Hive Thrift server within $serverStartTimeout", cause)
case _ =>
}
logError(
@@ -125,8 +172,8 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
|HiveThriftServer2Suite failure output
|=====================================
|HiveThriftServer2 command line: ${command.mkString(" ")}
- |JDBC URI: $jdbcUri
- |User: $user
+ |Binding port: $port
+ |System user: ${System.getProperty("user.name")}
|
|${buffer.mkString("\n")}
|=========================================
@@ -146,7 +193,7 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
}
test("Test JDBC query execution") {
- startThriftServerWithin() { statement =>
+ withJdbcStatement() { statement =>
val dataFilePath =
Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt")
@@ -168,7 +215,7 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
}
test("SPARK-3004 regression: result set containing NULL") {
- startThriftServerWithin() { statement =>
+ withJdbcStatement() { statement =>
val dataFilePath =
Thread.currentThread().getContextClassLoader.getResource(
"data/files/small_kv_with_null.txt")
@@ -191,4 +238,33 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
assert(!resultSet.next())
}
}
+
+ test("GetInfo Thrift API") {
+ withCLIServiceClient() { client =>
+ val user = System.getProperty("user.name")
+ val sessionHandle = client.openSession(user, "")
+
+ assertResult("Spark SQL", "Wrong GetInfo(CLI_DBMS_NAME) result") {
+ client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_NAME).getStringValue
+ }
+
+ assertResult("Spark SQL", "Wrong GetInfo(CLI_SERVER_NAME) result") {
+ client.getInfo(sessionHandle, GetInfoType.CLI_SERVER_NAME).getStringValue
+ }
+
+ assertResult(true, "Spark version shouldn't be \"Unknown\"") {
+ val version = client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_VER).getStringValue
+ logInfo(s"Spark version: $version")
+ version != "Unknown"
+ }
+ }
+ }
+
+ test("Checks Hive version") {
+ withJdbcStatement() { statement =>
+ val resultSet = statement.executeQuery("SET spark.sql.hive.version")
+ resultSet.next()
+ assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}")
+ }
+ }
}