Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala                                   |  69
-rw-r--r--  sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala     |  14
-rw-r--r--  sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala            |  11
-rw-r--r--  sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala | 144
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala                                      |   4
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala                         |  35
6 files changed, 165 insertions(+), 112 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index e658e6fc4d..f23b9c48cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -84,50 +84,35 @@ case class SetCommand(kv: Option[(String, Option[String])], output: Seq[Attribut
extends LeafNode with Command with Logging {
override protected lazy val sideEffectResult: Seq[Row] = kv match {
- // Set value for the key.
- case Some((key, Some(value))) =>
- if (key == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) {
- logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " +
+ // Configures the deprecated "mapred.reduce.tasks" property.
+ case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, Some(value))) =>
+ logWarning(
+ s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " +
s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.")
- context.setConf(SQLConf.SHUFFLE_PARTITIONS, value)
- Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$value"))
- } else {
- context.setConf(key, value)
- Seq(Row(s"$key=$value"))
- }
-
- // Query the value bound to the key.
+ context.setConf(SQLConf.SHUFFLE_PARTITIONS, value)
+ Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$value"))
+
+ // Configures a single property.
+ case Some((key, Some(value))) =>
+ context.setConf(key, value)
+ Seq(Row(s"$key=$value"))
+
+ // Queries all key-value pairs that are set in the SQLConf of the context. Note that unlike
+ // Hive, "SET -v" here is an alias of "SET". (In Hive, "SET" returns all changed properties
+ // while "SET -v" returns all properties.)
+ case Some(("-v", None)) | None =>
+ context.getAllConfs.map { case (k, v) => Row(s"$k=$v") }.toSeq
+
+ // Queries the deprecated "mapred.reduce.tasks" property.
+ case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, None)) =>
+ logWarning(
+ s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " +
+ s"showing ${SQLConf.SHUFFLE_PARTITIONS} instead.")
+ Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=${context.numShufflePartitions}"))
+
+ // Queries a single property.
case Some((key, None)) =>
- // TODO (lian) This is just a workaround to make the Simba ODBC driver work.
- // Should remove this once we get the ODBC driver updated.
- if (key == "-v") {
- val hiveJars = Seq(
- "hive-exec-0.12.0.jar",
- "hive-service-0.12.0.jar",
- "hive-common-0.12.0.jar",
- "hive-hwi-0.12.0.jar",
- "hive-0.12.0.jar").mkString(":")
-
- context.getAllConfs.map { case (k, v) =>
- Row(s"$k=$v")
- }.toSeq ++ Seq(
- Row("system:java.class.path=" + hiveJars),
- Row("system:sun.java.command=shark.SharkServer2"))
- } else {
- if (key == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) {
- logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " +
- s"showing ${SQLConf.SHUFFLE_PARTITIONS} instead.")
- Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=${context.numShufflePartitions}"))
- } else {
- Seq(Row(s"$key=${context.getConf(key, "<undefined>")}"))
- }
- }
-
- // Query all key-value pairs that are set in the SQLConf of the context.
- case _ =>
- context.getAllConfs.map { case (k, v) =>
- Row(s"$k=$v")
- }.toSeq
+ Seq(Row(s"$key=${context.getConf(key, "<undefined>")}"))
}
override def otherCopyArgs = context :: Nil
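For reference, a minimal sketch of the resulting SET semantics, assuming an existing SparkContext `sc` and an illustrative SQLContext named `sqlContext` (neither is part of this patch; the commented values describe the rows produced by SetCommand):

    import org.apache.spark.sql.SQLContext

    val sqlContext = new SQLContext(sc)

    sqlContext.sql("SET spark.sql.shuffle.partitions=10").collect()
    // Array(Row("spark.sql.shuffle.partitions=10"))

    sqlContext.sql("SET mapred.reduce.tasks=20").collect()
    // Logs a deprecation warning and sets spark.sql.shuffle.partitions instead:
    // Array(Row("spark.sql.shuffle.partitions=20"))

    sqlContext.sql("SET spark.sql.shuffle.partitions").collect()
    // Array(Row("spark.sql.shuffle.partitions=20"))

    sqlContext.sql("SET some.unknown.key").collect()
    // Array(Row("some.unknown.key=<undefined>"))

    sqlContext.sql("SET").collect()      // every key=value pair currently set in SQLConf
    sqlContext.sql("SET -v").collect()   // same as plain SET (unlike Hive), now that the ODBC workaround is gone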
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala
index a78311fc48..ecfb74473e 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.hive.thriftserver
+import java.util.jar.Attributes.Name
+
import scala.collection.JavaConversions._
import java.io.IOException
@@ -29,11 +31,12 @@ import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.shims.ShimLoader
import org.apache.hive.service.Service.STATE
import org.apache.hive.service.auth.HiveAuthFactory
-import org.apache.hive.service.cli.CLIService
+import org.apache.hive.service.cli._
import org.apache.hive.service.{AbstractService, Service, ServiceException}
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
+import org.apache.spark.util.Utils
private[hive] class SparkSQLCLIService(hiveContext: HiveContext)
extends CLIService
@@ -60,6 +63,15 @@ private[hive] class SparkSQLCLIService(hiveContext: HiveContext)
initCompositeService(hiveConf)
}
+
+ override def getInfo(sessionHandle: SessionHandle, getInfoType: GetInfoType): GetInfoValue = {
+ getInfoType match {
+ case GetInfoType.CLI_SERVER_NAME => new GetInfoValue("Spark SQL")
+ case GetInfoType.CLI_DBMS_NAME => new GetInfoValue("Spark SQL")
+ case GetInfoType.CLI_DBMS_VER => new GetInfoValue(Utils.sparkVersion)
+ case _ => super.getInfo(sessionHandle, getInfoType)
+ }
+ }
}
private[thriftserver] trait ReflectedCompositeService { this: AbstractService =>
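The overridden getInfo lets ODBC/JDBC front ends identify the server. A rough sketch of reading these values through the Thrift CLI client, assuming a transport and session opened the same way as the new "GetInfo Thrift API" test below (the local names are illustrative, not part of the patch):

    import org.apache.hive.service.cli.{GetInfoType, SessionHandle}
    import org.apache.hive.service.cli.thrift.ThriftCLIServiceClient

    // Hedged sketch: `client` is a connected ThriftCLIServiceClient, `sessionHandle` an open session.
    def describeServer(client: ThriftCLIServiceClient, sessionHandle: SessionHandle): Unit = {
      val serverName = client.getInfo(sessionHandle, GetInfoType.CLI_SERVER_NAME).getStringValue // "Spark SQL"
      val dbmsName   = client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_NAME).getStringValue   // "Spark SQL"
      val dbmsVer    = client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_VER).getStringValue    // Spark version from the jar manifest
      println(s"$serverName / $dbmsName / $dbmsVer")
    }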
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
index 5042586351..89732c939b 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
@@ -17,10 +17,11 @@
package org.apache.spark.sql.hive.thriftserver
+import scala.collection.JavaConversions._
+
import org.apache.spark.scheduler.StatsReportListener
-import org.apache.spark.sql.hive.HiveContext
+import org.apache.spark.sql.hive.{HiveShim, HiveContext}
import org.apache.spark.{Logging, SparkConf, SparkContext}
-import scala.collection.JavaConversions._
/** A singleton object for the master program. The slaves should not access this. */
private[hive] object SparkSQLEnv extends Logging {
@@ -31,8 +32,10 @@ private[hive] object SparkSQLEnv extends Logging {
def init() {
if (hiveContext == null) {
- sparkContext = new SparkContext(new SparkConf()
- .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}"))
+ val sparkConf = new SparkConf()
+ .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}")
+ .set("spark.sql.hive.version", HiveShim.version)
+ sparkContext = new SparkContext(sparkConf)
sparkContext.addSparkListener(new StatsReportListener())
hiveContext = new HiveContext(sparkContext)
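Because spark.sql.hive.version is set on the SparkConf before the HiveContext is created, it becomes visible to SQL clients as an ordinary SQLConf property; the new "Checks Hive version" test below verifies this end to end over JDBC. A hedged sketch of the same check done locally, assuming spark.sql.* entries on the SparkConf are picked up by the context's SQLConf (which is what that test relies on):

    import org.apache.spark.sql.hive.HiveShim

    SparkSQLEnv.init()
    val reported = SparkSQLEnv.hiveContext.getConf("spark.sql.hive.version", "<undefined>")
    assert(reported == HiveShim.version)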
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
index c60e8fa5b1..65d910a0c3 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
@@ -30,42 +30,95 @@ import scala.util.Try
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hive.jdbc.HiveDriver
+import org.apache.hive.service.auth.PlainSaslHelper
+import org.apache.hive.service.cli.GetInfoType
+import org.apache.hive.service.cli.thrift.TCLIService.Client
+import org.apache.hive.service.cli.thrift._
+import org.apache.thrift.protocol.TBinaryProtocol
+import org.apache.thrift.transport.TSocket
import org.scalatest.FunSuite
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.util.getTempFilePath
+import org.apache.spark.sql.hive.HiveShim
/**
* Tests for the HiveThriftServer2 using JDBC.
+ *
+ * NOTE: SPARK_PREPEND_CLASSES is explicitly disabled in this test suite. The assembly jar must
+ * be rebuilt after changing HiveThriftServer2-related code.
*/
class HiveThriftServer2Suite extends FunSuite with Logging {
Class.forName(classOf[HiveDriver].getCanonicalName)
- def startThriftServerWithin(timeout: FiniteDuration = 1.minute)(f: Statement => Unit) {
+ def randomListeningPort = {
+ // Let the system choose a random available port to avoid collisions with other parallel
+ // builds.
+ val socket = new ServerSocket(0)
+ val port = socket.getLocalPort
+ socket.close()
+ port
+ }
+
+ def withJdbcStatement(serverStartTimeout: FiniteDuration = 1.minute)(f: Statement => Unit) {
+ val port = randomListeningPort
+
+ startThriftServer(port, serverStartTimeout) {
+ val jdbcUri = s"jdbc:hive2://${"localhost"}:$port/"
+ val user = System.getProperty("user.name")
+ val connection = DriverManager.getConnection(jdbcUri, user, "")
+ val statement = connection.createStatement()
+
+ try {
+ f(statement)
+ } finally {
+ statement.close()
+ connection.close()
+ }
+ }
+ }
+
+ def withCLIServiceClient(
+ serverStartTimeout: FiniteDuration = 1.minute)(
+ f: ThriftCLIServiceClient => Unit) {
+ val port = randomListeningPort
+
+ startThriftServer(port) {
+ // The transport creation logic below mimics HiveConnection.createBinaryTransport.
+ val rawTransport = new TSocket("localhost", port)
+ val user = System.getProperty("user.name")
+ val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport)
+ val protocol = new TBinaryProtocol(transport)
+ val client = new ThriftCLIServiceClient(new Client(protocol))
+
+ transport.open()
+
+ try {
+ f(client)
+ } finally {
+ transport.close()
+ }
+ }
+ }
+
+ def startThriftServer(
+ port: Int,
+ serverStartTimeout: FiniteDuration = 1.minute)(
+ f: => Unit) {
val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator)
val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator)
val warehousePath = getTempFilePath("warehouse")
val metastorePath = getTempFilePath("metastore")
val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true"
- val listeningHost = "localhost"
- val listeningPort = {
- // Let the system to choose a random available port to avoid collision with other parallel
- // builds.
- val socket = new ServerSocket(0)
- val port = socket.getLocalPort
- socket.close()
- port
- }
-
val command =
s"""$startScript
| --master local
| --hiveconf hive.root.logger=INFO,console
| --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri
| --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath
- | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=$listeningHost
- | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$listeningPort
+ | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=${"localhost"}
+ | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$port
""".stripMargin.split("\\s+").toSeq
val serverRunning = Promise[Unit]()
@@ -92,31 +145,25 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
}
}
- // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths
- Process(command, None, "SPARK_TESTING" -> "0").run(ProcessLogger(
+ val env = Seq(
+ // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths
+ "SPARK_TESTING" -> "0",
+ // Prevents loading classes from outside the assembly jar. Otherwise Utils.sparkVersion can't
+ // read proper version information from the jar manifest.
+ "SPARK_PREPEND_CLASSES" -> "")
+
+ Process(command, None, env: _*).run(ProcessLogger(
captureThriftServerOutput("stdout"),
captureThriftServerOutput("stderr")))
- val jdbcUri = s"jdbc:hive2://$listeningHost:$listeningPort/"
- val user = System.getProperty("user.name")
-
try {
- Await.result(serverRunning.future, timeout)
-
- val connection = DriverManager.getConnection(jdbcUri, user, "")
- val statement = connection.createStatement()
-
- try {
- f(statement)
- } finally {
- statement.close()
- connection.close()
- }
+ Await.result(serverRunning.future, serverStartTimeout)
+ f
} catch {
case cause: Exception =>
cause match {
case _: TimeoutException =>
- logError(s"Failed to start Hive Thrift server within $timeout", cause)
+ logError(s"Failed to start Hive Thrift server within $serverStartTimeout", cause)
case _ =>
}
logError(
@@ -125,8 +172,8 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
|HiveThriftServer2Suite failure output
|=====================================
|HiveThriftServer2 command line: ${command.mkString(" ")}
- |JDBC URI: $jdbcUri
- |User: $user
+ |Binding port: $port
+ |System user: ${System.getProperty("user.name")}
|
|${buffer.mkString("\n")}
|=========================================
@@ -146,7 +193,7 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
}
test("Test JDBC query execution") {
- startThriftServerWithin() { statement =>
+ withJdbcStatement() { statement =>
val dataFilePath =
Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt")
@@ -168,7 +215,7 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
}
test("SPARK-3004 regression: result set containing NULL") {
- startThriftServerWithin() { statement =>
+ withJdbcStatement() { statement =>
val dataFilePath =
Thread.currentThread().getContextClassLoader.getResource(
"data/files/small_kv_with_null.txt")
@@ -191,4 +238,33 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
assert(!resultSet.next())
}
}
+
+ test("GetInfo Thrift API") {
+ withCLIServiceClient() { client =>
+ val user = System.getProperty("user.name")
+ val sessionHandle = client.openSession(user, "")
+
+ assertResult("Spark SQL", "Wrong GetInfo(CLI_DBMS_NAME) result") {
+ client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_NAME).getStringValue
+ }
+
+ assertResult("Spark SQL", "Wrong GetInfo(CLI_SERVER_NAME) result") {
+ client.getInfo(sessionHandle, GetInfoType.CLI_SERVER_NAME).getStringValue
+ }
+
+ assertResult(true, "Spark version shouldn't be \"Unknown\"") {
+ val version = client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_VER).getStringValue
+ logInfo(s"Spark version: $version")
+ version != "Unknown"
+ }
+ }
+ }
+
+ test("Checks Hive version") {
+ withJdbcStatement() { statement =>
+ val resultSet = statement.executeQuery("SET spark.sql.hive.version")
+ resultSet.next()
+ assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}")
+ }
+ }
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index dca5367f24..0fe59f42f2 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -323,7 +323,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
driver.close()
HiveShim.processResults(results)
case _ =>
- sessionState.out.println(tokens(0) + " " + cmd_1)
+ if (sessionState.out != null) {
+ sessionState.out.println(tokens(0) + " " + cmd_1)
+ }
Seq(proc.run(cmd_1).getResponseCode.toString)
}
} catch {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 5918f888c8..b897dff015 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -769,7 +769,7 @@ class HiveQuerySuite extends HiveComparisonTest {
}.toSet
clear()
- // "set" itself returns all config variables currently specified in SQLConf.
+ // "SET" itself returns all config variables currently specified in SQLConf.
// TODO: Should we be listing the default here always? probably...
assert(sql("SET").collect().size == 0)
@@ -778,44 +778,19 @@ class HiveQuerySuite extends HiveComparisonTest {
}
assert(hiveconf.get(testKey, "") == testVal)
- assertResult(Set(testKey -> testVal)) {
- collectResults(sql("SET"))
- }
+ assertResult(Set(testKey -> testVal))(collectResults(sql("SET")))
+ assertResult(Set(testKey -> testVal))(collectResults(sql("SET -v")))
sql(s"SET ${testKey + testKey}=${testVal + testVal}")
assert(hiveconf.get(testKey + testKey, "") == testVal + testVal)
assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) {
collectResults(sql("SET"))
}
-
- // "set key"
- assertResult(Set(testKey -> testVal)) {
- collectResults(sql(s"SET $testKey"))
- }
-
- assertResult(Set(nonexistentKey -> "<undefined>")) {
- collectResults(sql(s"SET $nonexistentKey"))
- }
-
- // Assert that sql() should have the same effects as sql() by repeating the above using sql().
- clear()
- assert(sql("SET").collect().size == 0)
-
- assertResult(Set(testKey -> testVal)) {
- collectResults(sql(s"SET $testKey=$testVal"))
- }
-
- assert(hiveconf.get(testKey, "") == testVal)
- assertResult(Set(testKey -> testVal)) {
- collectResults(sql("SET"))
- }
-
- sql(s"SET ${testKey + testKey}=${testVal + testVal}")
- assert(hiveconf.get(testKey + testKey, "") == testVal + testVal)
assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) {
- collectResults(sql("SET"))
+ collectResults(sql("SET -v"))
}
+ // "SET key"
assertResult(Set(testKey -> testVal)) {
collectResults(sql(s"SET $testKey"))
}