about summary refs log tree commit diff
path: root/sql/hive-thriftserver/src
diff options
context:
space:
mode:
Diffstat (limited to 'sql/hive-thriftserver/src')
-rw-r--r-- sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala 85
1 file changed, 56 insertions, 29 deletions
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
index 139d8e897b..ebb2575416 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
@@ -23,9 +23,8 @@ import java.sql.{Date, DriverManager, SQLException, Statement}
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
-import scala.concurrent.{Await, Promise, future}
+import scala.concurrent.{Await, ExecutionContext, Promise, future}
import scala.io.Source
import scala.util.{Random, Try}
@@ -43,7 +42,7 @@ import org.scalatest.BeforeAndAfterAll
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}
import org.apache.spark.{Logging, SparkFunSuite}
object TestData {
@@ -356,31 +355,54 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map")
queries.foreach(statement.execute)
-
- val largeJoin = "SELECT COUNT(*) FROM test_map " +
- List.fill(10)("join test_map").mkString(" ")
- val f = future { Thread.sleep(100); statement.cancel(); }
- val e = intercept[SQLException] {
- statement.executeQuery(largeJoin)
+ implicit val ec = ExecutionContext.fromExecutorService(
+ ThreadUtils.newDaemonSingleThreadExecutor("test-jdbc-cancel"))
+ try {
+ // Start a very-long-running query that will take hours to finish, then cancel it in order
+ // to demonstrate that cancellation works.
+ val f = future {
+ statement.executeQuery(
+ "SELECT COUNT(*) FROM test_map " +
+ List.fill(10)("join test_map").mkString(" "))
+ }
+ // Note that this is slightly race-prone: if the cancel is issued before the statement
+ // begins executing then we'll fail with a timeout. As a result, this fixed delay is set
+ // slightly more conservatively than may be strictly necessary.
+ Thread.sleep(1000)
+ statement.cancel()
+ val e = intercept[SQLException] {
+ Await.result(f, 3.minute)
+ }
+ assert(e.getMessage.contains("cancelled"))
+
+ // Cancellation is a no-op if spark.sql.hive.thriftServer.async=false
+ statement.executeQuery("SET spark.sql.hive.thriftServer.async=false")
+ try {
+ val sf = future {
+ statement.executeQuery(
+ "SELECT COUNT(*) FROM test_map " +
+ List.fill(4)("join test_map").mkString(" ")
+ )
+ }
+ // Similarly, this is also slightly race-prone on fast machines where the query above
+ // might race and complete before we issue the cancel.
+ Thread.sleep(1000)
+ statement.cancel()
+ val rs1 = Await.result(sf, 3.minute)
+ rs1.next()
+ assert(rs1.getInt(1) === math.pow(5, 5))
+ rs1.close()
+
+ val rs2 = statement.executeQuery("SELECT COUNT(*) FROM test_map")
+ rs2.next()
+ assert(rs2.getInt(1) === 5)
+ rs2.close()
+ } finally {
+ statement.executeQuery("SET spark.sql.hive.thriftServer.async=true")
+ }
+ } finally {
+ ec.shutdownNow()
}
- assert(e.getMessage contains "cancelled")
- Await.result(f, 3.minute)
-
- // cancel is a noop
- statement.executeQuery("SET spark.sql.hive.thriftServer.async=false")
- val sf = future { Thread.sleep(100); statement.cancel(); }
- val smallJoin = "SELECT COUNT(*) FROM test_map " +
- List.fill(4)("join test_map").mkString(" ")
- val rs1 = statement.executeQuery(smallJoin)
- Await.result(sf, 3.minute)
- rs1.next()
- assert(rs1.getInt(1) === math.pow(5, 5))
- rs1.close()
-
- val rs2 = statement.executeQuery("SELECT COUNT(*) FROM test_map")
- rs2.next()
- assert(rs2.getInt(1) === 5)
- rs2.close()
}
}
@@ -817,6 +839,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl
}
override protected def beforeAll(): Unit = {
+ super.beforeAll()
// Chooses a random port between 10000 and 19999
listeningPort = 10000 + Random.nextInt(10000)
diagnosisBuffer.clear()
@@ -838,7 +861,11 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl
}
override protected def afterAll(): Unit = {
- stopThriftServer()
- logInfo("HiveThriftServer2 stopped")
+ try {
+ stopThriftServer()
+ logInfo("HiveThriftServer2 stopped")
+ } finally {
+ super.afterAll()
+ }
}
}