author     Cheng Lian <lian@databricks.com>            2014-11-10 17:04:10 -0800
committer  Michael Armbrust <michael@databricks.com>   2014-11-10 17:04:40 -0800
commit     64945f868443fbc59cb34b34c16d782dda0fb63d (patch)
tree       b4352a7c282fb8e53669d22602b00933f2927259
parent     b3ef06b757383754a9173e81e5179946b12c7922 (diff)
[SPARK-3971][SQL] Backport #2843 to branch-1.1
This PR backports #2843 to branch-1.1. The key difference is that this one doesn't support Hive 0.13.1 and thus always returns `0.12.0` when `spark.sql.hive.version` is queried.

6 other commits on which #2843 depends were also backported, they are:

- #2887 for `SessionState` lifecycle control
- #2675, #2823 & #3060 for major test suite refactoring and bug fixes
- #2164, for Parquet test suites updates
- #2493, for reading `spark.sql.*` configurations

Author: Cheng Lian <lian@databricks.com>
Author: Cheng Lian <lian.cs.zju@gmail.com>
Author: Michael Armbrust <michael@databricks.com>

Closes #3113 from liancheng/get-info-for-1.1 and squashes the following commits:

d354161 [Cheng Lian] Provides Spark and Hive version in HiveThriftServer2 for branch-1.1
0c2a244 [Michael Armbrust] [SPARK-3646][SQL] Copy SQL configuration from SparkConf when a SQLContext is created.
3202a36 [Michael Armbrust] [SQL] Decrease partitions when testing
7f395b7 [Cheng Lian] [SQL] Fixes race condition in CliSuite
0dd28ec [Cheng Lian] [SQL] Fixes the race condition that may cause test failure
5928b39 [Cheng Lian] [SPARK-3809][SQL] Fixes test suites in hive-thriftserver
faeca62 [Cheng Lian] [SPARK-4037][SQL] Removes the SessionState instance created in HiveThriftServer2
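Editor's note (not part of the patch): the sketch below is a minimal illustration of the `spark.sql.*` propagation behavior that #2493 backports, mirroring the new code in the `SQLContext.scala` hunk and the "propagate from spark conf" test in `SQLConfSuite` further down. The object name `SqlConfPropagationExample` is invented for illustration; the snippet assumes the branch-1.1 API shown in this diff (`new SQLContext(sparkContext)` and the `getConf(key, default)` accessor used by the test).

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Hypothetical standalone example, not part of the Spark code base.
object SqlConfPropagationExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("SqlConfPropagationExample")
      .set("spark.sql.shuffle.partitions", "5") // copied into SQLConf: key starts with "spark.sql"
      .set("spark.executor.memory", "1g")       // ignored by SQLConf: not a spark.sql.* key

    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc) // constructor now copies spark.sql.* entries from SparkConf

    // Expected to print "5" after this patch; "<undefined>" before it.
    println(sqlContext.getConf("spark.sql.shuffle.partitions", "<undefined>"))

    sc.stop()
  }
}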
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala | 64
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala | 13
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala | 11
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala | 142
-rw-r--r--  sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala | 17
-rw-r--r--  sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala | 15
-rw-r--r--  sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala | 21
-rw-r--r--  sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala | 36
-rw-r--r--  sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala | 165
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala | 44
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala | 7
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala | 59
13 files changed, 307 insertions, 292 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index a75af94d29..4889fea24a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -75,6 +75,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution =
new this.QueryExecution { val logical = plan }
+ sparkContext.getConf.getAll.foreach {
+ case (key, value) if key.startsWith("spark.sql") => setConf(key, value)
+ case _ =>
+ }
+
/**
* :: DeveloperApi ::
* Allows catalyst LogicalPlans to be executed as a SchemaRDD. Note that the LogicalPlan
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index 031b695169..3429fbad02 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -48,43 +48,35 @@ case class SetCommand(
extends LeafNode with Command with Logging {
override protected[sql] lazy val sideEffectResult: Seq[String] = (key, value) match {
- // Set value for key k.
- case (Some(k), Some(v)) =>
- if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) {
- logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " +
+ // Configures the deprecated "mapred.reduce.tasks" property.
+ case (Some(SQLConf.Deprecated.MAPRED_REDUCE_TASKS), Some(v)) =>
+ logWarning(
+ s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " +
s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.")
- context.setConf(SQLConf.SHUFFLE_PARTITIONS, v)
- Array(s"${SQLConf.SHUFFLE_PARTITIONS}=$v")
- } else {
- context.setConf(k, v)
- Array(s"$k=$v")
- }
-
- // Query the value bound to key k.
- case (Some(k), _) =>
- // TODO (lian) This is just a workaround to make the Simba ODBC driver work.
- // Should remove this once we get the ODBC driver updated.
- if (k == "-v") {
- val hiveJars = Seq(
- "hive-exec-0.12.0.jar",
- "hive-service-0.12.0.jar",
- "hive-common-0.12.0.jar",
- "hive-hwi-0.12.0.jar",
- "hive-0.12.0.jar").mkString(":")
-
- Array(
- "system:java.class.path=" + hiveJars,
- "system:sun.java.command=shark.SharkServer2")
- }
- else {
- Array(s"$k=${context.getConf(k, "<undefined>")}")
- }
-
- // Query all key-value pairs that are set in the SQLConf of the context.
- case (None, None) =>
- context.getAllConfs.map { case (k, v) =>
- s"$k=$v"
- }.toSeq
+ context.setConf(SQLConf.SHUFFLE_PARTITIONS, v)
+ Seq(s"${SQLConf.SHUFFLE_PARTITIONS}=$v")
+
+ // Configures a single property.
+ case (Some(k), Some(v)) =>
+ context.setConf(k, v)
+ Seq(s"$k=$v")
+
+ // Queries all key-value pairs that are set in the SQLConf of the context. Notice that different
+ // from Hive, here "SET -v" is an alias of "SET". (In Hive, "SET" returns all changed properties
+ // while "SET -v" returns all properties.)
+ case (Some("-v") | None, None) =>
+ context.getAllConfs.map { case (k, v) => s"$k=$v" }.toSeq
+
+ // Queries the deprecated "mapred.reduce.tasks" property.
+ case (Some(SQLConf.Deprecated.MAPRED_REDUCE_TASKS), None) =>
+ logWarning(
+ s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " +
+ s"showing ${SQLConf.SHUFFLE_PARTITIONS} instead.")
+ Seq(s"${SQLConf.SHUFFLE_PARTITIONS}=${context.numShufflePartitions}")
+
+ // Queries a single property.
+ case (Some(k), None) =>
+ Seq(s"$k=${context.getConf(k, "<undefined>")}")
case _ =>
throw new IllegalArgumentException()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
index f2389f8f05..4fdfc2ba1b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -17,9 +17,18 @@
package org.apache.spark.sql.test
+import org.apache.spark.sql.{SQLConf, SQLContext}
import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.sql.SQLContext
/** A SQLContext that can be used for local testing. */
object TestSQLContext
- extends SQLContext(new SparkContext("local", "TestSQLContext", new SparkConf()))
+ extends SQLContext(
+ new SparkContext(
+ "local[2]",
+ "TestSQLContext",
+ new SparkConf().set("spark.sql.testkey", "true"))) {
+
+ /** Fewer partitions to speed up testing. */
+ override private[spark] def numShufflePartitions: Int =
+ getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
index 584f71b3c1..60701f0e15 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
@@ -17,16 +17,25 @@
package org.apache.spark.sql
+import org.scalatest.FunSuiteLike
+
import org.apache.spark.sql.test._
/* Implicits */
import TestSQLContext._
-class SQLConfSuite extends QueryTest {
+class SQLConfSuite extends QueryTest with FunSuiteLike {
val testKey = "test.key.0"
val testVal = "test.val.0"
+ test("propagate from spark conf") {
+ // We create a new context here to avoid order dependence with other tests that might call
+ // clear().
+ val newContext = new SQLContext(TestSQLContext.sparkContext)
+ assert(newContext.getConf("spark.sql.testkey", "false") == "true")
+ }
+
test("programmatic ways of basic setting and getting") {
clear()
assert(getAllConfs.size === 0)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index 42923b6a28..c6b790a4b6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -63,8 +63,7 @@ case class AllDataTypes(
doubleField: Double,
shortField: Short,
byteField: Byte,
- booleanField: Boolean,
- binaryField: Array[Byte])
+ booleanField: Boolean)
case class AllDataTypesWithNonPrimitiveType(
stringField: String,
@@ -75,13 +74,14 @@ case class AllDataTypesWithNonPrimitiveType(
shortField: Short,
byteField: Byte,
booleanField: Boolean,
- binaryField: Array[Byte],
array: Seq[Int],
arrayContainsNull: Seq[Option[Int]],
map: Map[Int, Long],
mapValueContainsNull: Map[Int, Option[Long]],
data: Data)
+case class BinaryData(binaryData: Array[Byte])
+
class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
TestData // Load test data tables.
@@ -117,26 +117,26 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
test("Read/Write All Types") {
val tempDir = getTempFilePath("parquetTest").getCanonicalPath
val range = (0 to 255)
- TestSQLContext.sparkContext.parallelize(range)
- .map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0,
- (0 to x).map(_.toByte).toArray))
- .saveAsParquetFile(tempDir)
- val result = parquetFile(tempDir).collect()
- range.foreach {
- i =>
- assert(result(i).getString(0) == s"$i", s"row $i String field did not match, got ${result(i).getString(0)}")
- assert(result(i).getInt(1) === i)
- assert(result(i).getLong(2) === i.toLong)
- assert(result(i).getFloat(3) === i.toFloat)
- assert(result(i).getDouble(4) === i.toDouble)
- assert(result(i).getShort(5) === i.toShort)
- assert(result(i).getByte(6) === i.toByte)
- assert(result(i).getBoolean(7) === (i % 2 == 0))
- assert(result(i)(8) === (0 to i).map(_.toByte).toArray)
- }
+ val data = sparkContext.parallelize(range)
+ .map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0))
+
+ data.saveAsParquetFile(tempDir)
+
+ checkAnswer(
+ parquetFile(tempDir),
+ data.toSchemaRDD.collect().toSeq)
}
- test("Treat binary as string") {
+ test("read/write binary data") {
+ // Since equality for Array[Byte] is broken we test this separately.
+ val tempDir = getTempFilePath("parquetTest").getCanonicalPath
+ sparkContext.parallelize(BinaryData("test".getBytes("utf8")) :: Nil).saveAsParquetFile(tempDir)
+ parquetFile(tempDir)
+ .map(r => new String(r(0).asInstanceOf[Array[Byte]], "utf8"))
+ .collect().toSeq == Seq("test")
+ }
+
+ ignore("Treat binary as string") {
val oldIsParquetBinaryAsString = TestSQLContext.isParquetBinaryAsString
// Create the test file.
@@ -151,37 +151,16 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
StructField("c2", BinaryType, false) :: Nil)
val schemaRDD1 = applySchema(rowRDD, schema)
schemaRDD1.saveAsParquetFile(path)
- val resultWithBinary = parquetFile(path).collect
- range.foreach {
- i =>
- assert(resultWithBinary(i).getInt(0) === i)
- assert(resultWithBinary(i)(1) === s"val_$i".getBytes)
- }
-
- TestSQLContext.setConf(SQLConf.PARQUET_BINARY_AS_STRING, "true")
- // This ParquetRelation always use Parquet types to derive output.
- val parquetRelation = new ParquetRelation(
- path.toString,
- Some(TestSQLContext.sparkContext.hadoopConfiguration),
- TestSQLContext) {
- override val output =
- ParquetTypesConverter.convertToAttributes(
- ParquetTypesConverter.readMetaData(new Path(path), conf).getFileMetaData.getSchema,
- TestSQLContext.isParquetBinaryAsString)
- }
- val schemaRDD = new SchemaRDD(TestSQLContext, parquetRelation)
- val resultWithString = schemaRDD.collect
- range.foreach {
- i =>
- assert(resultWithString(i).getInt(0) === i)
- assert(resultWithString(i)(1) === s"val_$i")
- }
+ checkAnswer(
+ parquetFile(path).select('c1, 'c2.cast(StringType)),
+ schemaRDD1.select('c1, 'c2.cast(StringType)).collect().toSeq)
- schemaRDD.registerTempTable("tmp")
+ setConf(SQLConf.PARQUET_BINARY_AS_STRING, "true")
+ parquetFile(path).printSchema()
checkAnswer(
- sql("SELECT c1, c2 FROM tmp WHERE c2 = 'val_5' OR c2 = 'val_7'"),
- (5, "val_5") ::
- (7, "val_7") :: Nil)
+ parquetFile(path),
+ schemaRDD1.select('c1, 'c2.cast(StringType)).collect().toSeq)
+
// Set it back.
TestSQLContext.setConf(SQLConf.PARQUET_BINARY_AS_STRING, oldIsParquetBinaryAsString.toString)
@@ -284,34 +263,19 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
test("Read/Write All Types with non-primitive type") {
val tempDir = getTempFilePath("parquetTest").getCanonicalPath
val range = (0 to 255)
- TestSQLContext.sparkContext.parallelize(range)
+ val data = sparkContext.parallelize(range)
.map(x => AllDataTypesWithNonPrimitiveType(
s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0,
- (0 to x).map(_.toByte).toArray,
(0 until x),
(0 until x).map(Option(_).filter(_ % 3 == 0)),
(0 until x).map(i => i -> i.toLong).toMap,
(0 until x).map(i => i -> Option(i.toLong)).toMap + (x -> None),
Data((0 until x), Nested(x, s"$x"))))
- .saveAsParquetFile(tempDir)
- val result = parquetFile(tempDir).collect()
- range.foreach {
- i =>
- assert(result(i).getString(0) == s"$i", s"row $i String field did not match, got ${result(i).getString(0)}")
- assert(result(i).getInt(1) === i)
- assert(result(i).getLong(2) === i.toLong)
- assert(result(i).getFloat(3) === i.toFloat)
- assert(result(i).getDouble(4) === i.toDouble)
- assert(result(i).getShort(5) === i.toShort)
- assert(result(i).getByte(6) === i.toByte)
- assert(result(i).getBoolean(7) === (i % 2 == 0))
- assert(result(i)(8) === (0 to i).map(_.toByte).toArray)
- assert(result(i)(9) === (0 until i))
- assert(result(i)(10) === (0 until i).map(i => if (i % 3 == 0) i else null))
- assert(result(i)(11) === (0 until i).map(i => i -> i.toLong).toMap)
- assert(result(i)(12) === (0 until i).map(i => i -> i.toLong).toMap + (i -> null))
- assert(result(i)(13) === new GenericRow(Array[Any]((0 until i), new GenericRow(Array[Any](i, s"$i")))))
- }
+ data.saveAsParquetFile(tempDir)
+
+ checkAnswer(
+ parquetFile(tempDir),
+ data.toSchemaRDD.collect().toSeq)
}
test("self-join parquet files") {
@@ -408,23 +372,6 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
}
}
- test("Saving case class RDD table to file and reading it back in") {
- val file = getTempFilePath("parquet")
- val path = file.toString
- val rdd = TestSQLContext.sparkContext.parallelize((1 to 100))
- .map(i => TestRDDEntry(i, s"val_$i"))
- rdd.saveAsParquetFile(path)
- val readFile = parquetFile(path)
- readFile.registerTempTable("tmpx")
- val rdd_copy = sql("SELECT * FROM tmpx").collect()
- val rdd_orig = rdd.collect()
- for(i <- 0 to 99) {
- assert(rdd_copy(i).apply(0) === rdd_orig(i).key, s"key error in line $i")
- assert(rdd_copy(i).apply(1) === rdd_orig(i).value, s"value error in line $i")
- }
- Utils.deleteRecursively(file)
- }
-
test("Read a parquet file instead of a directory") {
val file = getTempFilePath("parquet")
val path = file.toString
@@ -457,32 +404,19 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
sql("INSERT OVERWRITE INTO dest SELECT * FROM source").collect()
val rdd_copy1 = sql("SELECT * FROM dest").collect()
assert(rdd_copy1.size === 100)
- assert(rdd_copy1(0).apply(0) === 1)
- assert(rdd_copy1(0).apply(1) === "val_1")
- // TODO: why does collecting break things? It seems InsertIntoParquet::execute() is
- // executed twice otherwise?!
+
sql("INSERT INTO dest SELECT * FROM source")
- val rdd_copy2 = sql("SELECT * FROM dest").collect()
+ val rdd_copy2 = sql("SELECT * FROM dest").collect().sortBy(_.getInt(0))
assert(rdd_copy2.size === 200)
- assert(rdd_copy2(0).apply(0) === 1)
- assert(rdd_copy2(0).apply(1) === "val_1")
- assert(rdd_copy2(99).apply(0) === 100)
- assert(rdd_copy2(99).apply(1) === "val_100")
- assert(rdd_copy2(100).apply(0) === 1)
- assert(rdd_copy2(100).apply(1) === "val_1")
Utils.deleteRecursively(dirname)
}
test("Insert (appending) to same table via Scala API") {
- // TODO: why does collecting break things? It seems InsertIntoParquet::execute() is
- // executed twice otherwise?!
sql("INSERT INTO testsource SELECT * FROM testsource")
val double_rdd = sql("SELECT * FROM testsource").collect()
assert(double_rdd != null)
assert(double_rdd.size === 30)
- for(i <- (0 to 14)) {
- assert(double_rdd(i) === double_rdd(i+15), s"error: lines $i and ${i+15} to not match")
- }
+
// let's restore the original test data
Utils.deleteRecursively(ParquetTestData.testDir)
ParquetTestData.writeFile()
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
index cadf7aaf42..161f8c6199 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
@@ -17,11 +17,8 @@
package org.apache.spark.sql.hive.thriftserver
-import scala.collection.JavaConversions._
-
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.hive.conf.HiveConf
-import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hive.service.cli.thrift.ThriftBinaryCLIService
import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor}
@@ -38,24 +35,12 @@ private[hive] object HiveThriftServer2 extends Logging {
def main(args: Array[String]) {
val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2")
-
if (!optionsProcessor.process(args)) {
System.exit(-1)
}
- val ss = new SessionState(new HiveConf(classOf[SessionState]))
-
- // Set all properties specified via command line.
- val hiveConf: HiveConf = ss.getConf
- hiveConf.getAllProperties.toSeq.sortBy(_._1).foreach { case (k, v) =>
- logDebug(s"HiveConf var: $k=$v")
- }
-
- SessionState.start(ss)
-
logInfo("Starting SparkContext")
SparkSQLEnv.init()
- SessionState.start(ss)
Runtime.getRuntime.addShutdownHook(
new Thread() {
@@ -67,7 +52,7 @@ private[hive] object HiveThriftServer2 extends Logging {
try {
val server = new HiveThriftServer2(SparkSQLEnv.hiveContext)
- server.init(hiveConf)
+ server.init(SparkSQLEnv.hiveContext.hiveconf)
server.start()
logInfo("HiveThriftServer2 started")
} catch {
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala
index 42cbf363b2..94ec9978af 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala
@@ -17,18 +17,18 @@
package org.apache.spark.sql.hive.thriftserver
-import scala.collection.JavaConversions._
-
import java.io.IOException
import java.util.{List => JList}
import javax.security.auth.login.LoginException
+import scala.collection.JavaConversions._
+
import org.apache.commons.logging.Log
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.shims.ShimLoader
import org.apache.hive.service.Service.STATE
import org.apache.hive.service.auth.HiveAuthFactory
-import org.apache.hive.service.cli.CLIService
+import org.apache.hive.service.cli._
import org.apache.hive.service.{AbstractService, Service, ServiceException}
import org.apache.spark.sql.hive.HiveContext
@@ -57,6 +57,15 @@ private[hive] class SparkSQLCLIService(hiveContext: HiveContext)
initCompositeService(hiveConf)
}
+
+ override def getInfo(sessionHandle: SessionHandle, getInfoType: GetInfoType): GetInfoValue = {
+ getInfoType match {
+ case GetInfoType.CLI_SERVER_NAME => new GetInfoValue("Spark SQL")
+ case GetInfoType.CLI_DBMS_NAME => new GetInfoValue("Spark SQL")
+ case GetInfoType.CLI_DBMS_VER => new GetInfoValue(hiveContext.sparkContext.version)
+ case _ => super.getInfo(sessionHandle, getInfoType)
+ }
+ }
}
private[thriftserver] trait ReflectedCompositeService { this: AbstractService =>
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
index 582264eb59..e07402c56c 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
@@ -17,12 +17,11 @@
package org.apache.spark.sql.hive.thriftserver
-import org.apache.hadoop.hive.ql.session.SessionState
+import scala.collection.JavaConversions._
-import org.apache.spark.scheduler.{SplitInfo, StatsReportListener}
-import org.apache.spark.Logging
+import org.apache.spark.scheduler.StatsReportListener
import org.apache.spark.sql.hive.HiveContext
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.{Logging, SparkConf, SparkContext}
/** A singleton object for the master program. The slaves should not access this. */
private[hive] object SparkSQLEnv extends Logging {
@@ -33,14 +32,18 @@ private[hive] object SparkSQLEnv extends Logging {
def init() {
if (hiveContext == null) {
- sparkContext = new SparkContext(new SparkConf()
- .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}"))
+ val sparkConf = new SparkConf()
+ .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}")
+ .set("spark.sql.hive.version", "0.12.0")
+ sparkContext = new SparkContext(sparkConf)
sparkContext.addSparkListener(new StatsReportListener())
+ hiveContext = new HiveContext(sparkContext)
- hiveContext = new HiveContext(sparkContext) {
- @transient override lazy val sessionState = SessionState.get()
- @transient override lazy val hiveconf = sessionState.getConf
+ if (log.isDebugEnabled) {
+ hiveContext.hiveconf.getAllProperties.toSeq.sorted.foreach { case (k, v) =>
+ logDebug(s"HiveConf var: $k=$v")
+ }
}
}
}
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
index 3475c2c9db..e8ffbc5b95 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
@@ -18,15 +18,13 @@
package org.apache.spark.sql.hive.thriftserver
+import java.io._
+
import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
-import scala.concurrent.{Await, Future, Promise}
+import scala.concurrent.{Await, Promise}
import scala.sys.process.{Process, ProcessLogger}
-import java.io._
-import java.util.concurrent.atomic.AtomicInteger
-
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.scalatest.{BeforeAndAfterAll, FunSuite}
@@ -53,17 +51,19 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging {
""".stripMargin.split("\\s+").toSeq ++ extraArgs
}
- // AtomicInteger is needed because stderr and stdout of the forked process are handled in
- // different threads.
- val next = new AtomicInteger(0)
+ var next = 0
val foundAllExpectedAnswers = Promise.apply[Unit]()
val queryStream = new ByteArrayInputStream(queries.mkString("\n").getBytes)
val buffer = new ArrayBuffer[String]()
+ val lock = new Object
- def captureOutput(source: String)(line: String) {
+ def captureOutput(source: String)(line: String): Unit = lock.synchronized {
buffer += s"$source> $line"
- if (line.contains(expectedAnswers(next.get()))) {
- if (next.incrementAndGet() == expectedAnswers.size) {
+ // If we haven't found all expected answers and another expected answer comes up...
+ if (next < expectedAnswers.size && line.startsWith(expectedAnswers(next))) {
+ next += 1
+ // If all expected answers have been found...
+ if (next == expectedAnswers.size) {
foundAllExpectedAnswers.trySuccess(())
}
}
@@ -73,11 +73,6 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging {
val process = (Process(command) #< queryStream).run(
ProcessLogger(captureOutput("stdout"), captureOutput("stderr")))
- Future {
- val exitValue = process.exitValue()
- logInfo(s"Spark SQL CLI process exit value: $exitValue")
- }
-
try {
Await.result(foundAllExpectedAnswers.future, timeout)
} catch { case cause: Throwable =>
@@ -88,14 +83,15 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging {
|=======================
|Spark SQL CLI command line: ${command.mkString(" ")}
|
- |Executed query ${next.get()} "${queries(next.get())}",
- |But failed to capture expected output "${expectedAnswers(next.get())}" within $timeout.
+ |Executed query $next "${queries(next)}",
+ |But failed to capture expected output "${expectedAnswers(next)}" within $timeout.
|
|${buffer.mkString("\n")}
|===========================
|End CliSuite failure output
|===========================
""".stripMargin, cause)
+ throw cause
} finally {
warehousePath.delete()
metastorePath.delete()
@@ -107,7 +103,7 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging {
val dataFilePath =
Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt")
- runCliWithin(1.minute)(
+ runCliWithin(3.minute)(
"CREATE TABLE hive_test(key INT, val STRING);"
-> "OK",
"SHOW TABLES;"
@@ -118,7 +114,7 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging {
-> "Time taken: ",
"SELECT COUNT(*) FROM hive_test;"
-> "5",
- "DROP TABLE hive_test"
+ "DROP TABLE hive_test;"
-> "Time taken: "
)
}
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
index 38977ff162..08b4cc1c42 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
@@ -17,32 +17,39 @@
package org.apache.spark.sql.hive.thriftserver
-import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.ExecutionContext.Implicits.global
-import scala.concurrent.duration._
-import scala.concurrent.{Await, Future, Promise}
-import scala.sys.process.{Process, ProcessLogger}
-
import java.io.File
import java.net.ServerSocket
import java.sql.{DriverManager, Statement}
import java.util.concurrent.TimeoutException
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.duration._
+import scala.concurrent.{Await, Promise}
+import scala.sys.process.{Process, ProcessLogger}
+
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hive.jdbc.HiveDriver
+import org.apache.hive.service.auth.PlainSaslHelper
+import org.apache.hive.service.cli.GetInfoType
+import org.apache.hive.service.cli.thrift.TCLIService.Client
+import org.apache.hive.service.cli.thrift._
+import org.apache.thrift.protocol.TBinaryProtocol
+import org.apache.thrift.transport.TSocket
import org.scalatest.FunSuite
-import org.apache.spark.Logging
+import org.apache.spark.{SparkContext, Logging}
import org.apache.spark.sql.catalyst.util.getTempFilePath
/**
* Tests for the HiveThriftServer2 using JDBC.
+ *
+ * NOTE: SPARK_PREPEND_CLASSES is explicitly disabled in this test suite. Assembly jar must be
+ * rebuilt after changing HiveThriftServer2 related code.
*/
class HiveThriftServer2Suite extends FunSuite with Logging {
Class.forName(classOf[HiveDriver].getCanonicalName)
- private val listeningHost = "localhost"
- private val listeningPort = {
+ def randomListeningPort = {
// Let the system to choose a random available port to avoid collision with other parallel
// builds.
val socket = new ServerSocket(0)
@@ -51,61 +58,91 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
port
}
- private val warehousePath = getTempFilePath("warehouse")
- private val metastorePath = getTempFilePath("metastore")
- private val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true"
+ def withJdbcStatement(serverStartTimeout: FiniteDuration = 1.minute)(f: Statement => Unit) {
+ val port = randomListeningPort
+
+ startThriftServer(port, serverStartTimeout) {
+ val jdbcUri = s"jdbc:hive2://${"localhost"}:$port/"
+ val user = System.getProperty("user.name")
+ val connection = DriverManager.getConnection(jdbcUri, user, "")
+ val statement = connection.createStatement()
- def startThriftServerWithin(timeout: FiniteDuration = 30.seconds)(f: Statement => Unit) {
- val serverScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator)
+ try {
+ f(statement)
+ } finally {
+ statement.close()
+ connection.close()
+ }
+ }
+ }
+
+ def withCLIServiceClient(
+ serverStartTimeout: FiniteDuration = 1.minute)(
+ f: ThriftCLIServiceClient => Unit) {
+ val port = randomListeningPort
+
+ startThriftServer(port) {
+ // Transport creation logics below mimics HiveConnection.createBinaryTransport
+ val rawTransport = new TSocket("localhost", port)
+ val user = System.getProperty("user.name")
+ val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport)
+ val protocol = new TBinaryProtocol(transport)
+ val client = new ThriftCLIServiceClient(new Client(protocol))
+
+ transport.open()
+
+ try {
+ f(client)
+ } finally {
+ transport.close()
+ }
+ }
+ }
+ def startThriftServer(
+ port: Int,
+ serverStartTimeout: FiniteDuration = 1.minute)(
+ f: => Unit) {
+ val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator)
+
+ val warehousePath = getTempFilePath("warehouse")
+ val metastorePath = getTempFilePath("metastore")
+ val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true"
val command =
- s"""$serverScript
+ s"""$startScript
| --master local
| --hiveconf hive.root.logger=INFO,console
| --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri
| --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath
- | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=$listeningHost
- | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$listeningPort
+ | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=${"localhost"}
+ | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$port
""".stripMargin.split("\\s+").toSeq
- val serverStarted = Promise[Unit]()
+ val serverRunning = Promise[Unit]()
val buffer = new ArrayBuffer[String]()
+ val lock = new Object
- def captureOutput(source: String)(line: String) {
+ def captureOutput(source: String)(line: String): Unit = lock.synchronized {
buffer += s"$source> $line"
if (line.contains("ThriftBinaryCLIService listening on")) {
- serverStarted.success(())
+ serverRunning.success(())
}
}
- val process = Process(command).run(
- ProcessLogger(captureOutput("stdout"), captureOutput("stderr")))
-
- Future {
- val exitValue = process.exitValue()
- logInfo(s"Spark SQL Thrift server process exit value: $exitValue")
- }
+ // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths
+ val env = Seq("SPARK_TESTING" -> "0")
- val jdbcUri = s"jdbc:hive2://$listeningHost:$listeningPort/"
- val user = System.getProperty("user.name")
+ val process = Process(command, None, env: _*).run(
+ ProcessLogger(captureOutput("stdout"), captureOutput("stderr")))
try {
- Await.result(serverStarted.future, timeout)
-
- val connection = DriverManager.getConnection(jdbcUri, user, "")
- val statement = connection.createStatement()
-
- try {
- f(statement)
- } finally {
- statement.close()
- connection.close()
- }
+ Await.result(serverRunning.future, serverStartTimeout)
+ f
} catch {
case cause: Exception =>
cause match {
case _: TimeoutException =>
- logError(s"Failed to start Hive Thrift server within $timeout", cause)
+ logError(s"Failed to start Hive Thrift server within $serverStartTimeout", cause)
case _ =>
}
logError(
@@ -114,14 +151,15 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
|HiveThriftServer2Suite failure output
|=====================================
|HiveThriftServer2 command line: ${command.mkString(" ")}
- |JDBC URI: $jdbcUri
- |User: $user
+ |Binding port: $port
+ |System user: ${System.getProperty("user.name")}
|
|${buffer.mkString("\n")}
|=========================================
|End HiveThriftServer2Suite failure output
|=========================================
""".stripMargin, cause)
+ throw cause
} finally {
warehousePath.delete()
metastorePath.delete()
@@ -130,14 +168,16 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
}
test("Test JDBC query execution") {
- startThriftServerWithin() { statement =>
+ withJdbcStatement() { statement =>
val dataFilePath =
Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt")
- val queries = Seq(
- "CREATE TABLE test(key INT, val STRING)",
- s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test",
- "CACHE TABLE test")
+ val queries =
+ s"""SET spark.sql.shuffle.partitions=3;
+ |CREATE TABLE test(key INT, val STRING);
+ |LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test;
+ |CACHE TABLE test;
+ """.stripMargin.split(";").map(_.trim).filter(_.nonEmpty)
queries.foreach(statement.execute)
@@ -150,7 +190,7 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
}
test("SPARK-3004 regression: result set containing NULL") {
- startThriftServerWithin() { statement =>
+ withJdbcStatement() { statement =>
val dataFilePath =
Thread.currentThread().getContextClassLoader.getResource(
"data/files/small_kv_with_null.txt")
@@ -173,4 +213,31 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
assert(!resultSet.next())
}
}
+
+ test("GetInfo Thrift API") {
+ withCLIServiceClient() { client =>
+ val user = System.getProperty("user.name")
+ val sessionHandle = client.openSession(user, "")
+
+ assertResult("Spark SQL", "Wrong GetInfo(CLI_DBMS_NAME) result") {
+ client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_NAME).getStringValue
+ }
+
+ assertResult("Spark SQL", "Wrong GetInfo(CLI_SERVER_NAME) result") {
+ client.getInfo(sessionHandle, GetInfoType.CLI_SERVER_NAME).getStringValue
+ }
+
+ assertResult(SparkContext.SPARK_VERSION, "Spark version shouldn't be \"Unknown\"") {
+ client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_VER).getStringValue
+ }
+ }
+ }
+
+ test("Checks Hive version") {
+ withJdbcStatement() { statement =>
+ val resultSet = statement.executeQuery("SET spark.sql.hive.version")
+ resultSet.next()
+ assert(resultSet.getString(1) === s"spark.sql.hive.version=0.12.0")
+ }
+ }
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index d9b2bc7348..b44a94c6ae 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -222,17 +222,29 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
}
/**
- * SQLConf and HiveConf contracts: when the hive session is first initialized, params in
- * HiveConf will get picked up by the SQLConf. Additionally, any properties set by
- * set() or a SET command inside sql() will be set in the SQLConf *as well as*
- * in the HiveConf.
+ * SQLConf and HiveConf contracts:
+ *
+ * 1. reuse existing started SessionState if any
+ * 2. when the Hive session is first initialized, params in HiveConf will get picked up by the
+ * SQLConf. Additionally, any properties set by set() or a SET command inside sql() will be
+ * set in the SQLConf *as well as* in the HiveConf.
*/
- @transient protected[hive] lazy val hiveconf = new HiveConf(classOf[SessionState])
- @transient protected[hive] lazy val sessionState = {
- val ss = new SessionState(hiveconf)
- setConf(hiveconf.getAllProperties) // Have SQLConf pick up the initial set of HiveConf.
- ss
- }
+ @transient protected[hive] lazy val (hiveconf, sessionState) =
+ Option(SessionState.get())
+ .orElse {
+ val newState = new SessionState(new HiveConf(classOf[SessionState]))
+ // Only starts newly created `SessionState` instance. Any existing `SessionState` instance
+ // returned by `SessionState.get()` must be the most recently started one.
+ SessionState.start(newState)
+ Some(newState)
+ }
+ .map { state =>
+ setConf(state.getConf.getAllProperties)
+ if (state.out == null) state.out = new PrintStream(outputBuffer, true, "UTF-8")
+ if (state.err == null) state.err = new PrintStream(outputBuffer, true, "UTF-8")
+ (state.getConf, state)
+ }
+ .get
sessionState.err = new PrintStream(outputBuffer, true, "UTF-8")
sessionState.out = new PrintStream(outputBuffer, true, "UTF-8")
@@ -290,6 +302,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
SessionState.start(sessionState)
+ // Makes sure the session represented by the `sessionState` field is activated. This implies
+ // Spark SQL Hive support uses a single `SessionState` for all Hive operations and breaks
+ // session isolation under multi-user scenarios (i.e. HiveThriftServer2).
+ // TODO Fix session isolation
+ if (SessionState.get() != sessionState) {
+ SessionState.start(sessionState)
+ }
+
proc match {
case driver: Driver =>
driver.init()
@@ -306,7 +326,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
driver.destroy()
results
case _ =>
- sessionState.out.println(tokens(0) + " " + cmd_1)
+ if (sessionState.out != null) {
+ sessionState.out.println(tokens(0) + " " + cmd_1)
+ }
Seq(proc.run(cmd_1).getResponseCode.toString)
}
} catch {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
index 8bb2216b7b..094e58e986 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
@@ -35,12 +35,13 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.plans.logical.{CacheCommand, LogicalPlan, NativeCommand}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.hive._
+import org.apache.spark.sql.SQLConf
/* Implicit conversions */
import scala.collection.JavaConversions._
object TestHive
- extends TestHiveContext(new SparkContext("local", "TestSQLContext", new SparkConf()))
+ extends TestHiveContext(new SparkContext("local[2]", "TestSQLContext", new SparkConf()))
/**
* A locally running test instance of Spark's Hive execution engine.
@@ -90,6 +91,10 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
override def executePlan(plan: LogicalPlan): this.QueryExecution =
new this.QueryExecution { val logical = plan }
+ /** Fewer partitions to speed up testing. */
+ override private[spark] def numShufflePartitions: Int =
+ getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt
+
/**
* Returns the value of specified environmental variable as a [[java.io.File]] after checking
* to ensure it exists
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index d258743195..cdf9844207 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.hive.execution
import scala.util.Try
-import org.apache.spark.sql.{SchemaRDD, Row}
import org.apache.spark.sql.hive._
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
@@ -313,10 +312,10 @@ class HiveQuerySuite extends HiveComparisonTest {
"SELECT srcalias.KEY, SRCALIAS.value FROM sRc SrCAlias WHERE SrCAlias.kEy < 15")
test("case sensitivity: registered table") {
- val testData: SchemaRDD =
+ val testData =
TestHive.sparkContext.parallelize(
TestData(1, "str1") ::
- TestData(2, "str2") :: Nil)
+ TestData(2, "str2") :: Nil).toSchemaRDD
testData.registerTempTable("REGisteredTABle")
assertResult(Array(Array(2, "str2"))) {
@@ -327,7 +326,7 @@ class HiveQuerySuite extends HiveComparisonTest {
def isExplanation(result: SchemaRDD) = {
val explanation = result.select('plan).collect().map { case Row(plan: String) => plan }
- explanation.exists(_ == "== Physical Plan ==")
+ explanation.contains("== Physical Plan ==")
}
test("SPARK-1704: Explain commands as a SchemaRDD") {
@@ -467,10 +466,10 @@ class HiveQuerySuite extends HiveComparisonTest {
}
// Describe a registered temporary table.
- val testData: SchemaRDD =
+ val testData =
TestHive.sparkContext.parallelize(
TestData(1, "str1") ::
- TestData(1, "str2") :: Nil)
+ TestData(1, "str2") :: Nil).toSchemaRDD
testData.registerTempTable("test_describe_commands2")
assertResult(
@@ -520,10 +519,15 @@ class HiveQuerySuite extends HiveComparisonTest {
val testKey = "spark.sql.key.usedfortestonly"
val testVal = "test.val.0"
val nonexistentKey = "nonexistent"
-
+ val KV = "([^=]+)=([^=]*)".r
+ def collectResults(rdd: SchemaRDD): Set[(String, String)] =
+ rdd.collect().map {
+ case Row(key: String, value: String) => key -> value
+ case Row(KV(key, value)) => key -> value
+ }.toSet
clear()
- // "set" itself returns all config variables currently specified in SQLConf.
+ // "SET" itself returns all config variables currently specified in SQLConf.
// TODO: Should we be listing the default here always? probably...
assert(sql("SET").collect().size == 0)
@@ -532,46 +536,21 @@ class HiveQuerySuite extends HiveComparisonTest {
}
assert(hiveconf.get(testKey, "") == testVal)
- assertResult(Array(s"$testKey=$testVal")) {
- sql(s"SET $testKey=$testVal").collect().map(_.getString(0))
- }
+ assertResult(Set(testKey -> testVal))(collectResults(sql("SET")))
+ assertResult(Set(testKey -> testVal))(collectResults(sql("SET -v")))
sql(s"SET ${testKey + testKey}=${testVal + testVal}")
assert(hiveconf.get(testKey + testKey, "") == testVal + testVal)
assertResult(Array(s"$testKey=$testVal", s"${testKey + testKey}=${testVal + testVal}")) {
sql(s"SET").collect().map(_.getString(0))
}
-
- // "set key"
- assertResult(Array(s"$testKey=$testVal")) {
- sql(s"SET $testKey").collect().map(_.getString(0))
- }
-
- assertResult(Array(s"$nonexistentKey=<undefined>")) {
- sql(s"SET $nonexistentKey").collect().map(_.getString(0))
+ assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) {
+ collectResults(sql("SET -v"))
}
- // Assert that sql() should have the same effects as sql() by repeating the above using sql().
- clear()
- assert(sql("SET").collect().size == 0)
-
- assertResult(Array(s"$testKey=$testVal")) {
- sql(s"SET $testKey=$testVal").collect().map(_.getString(0))
- }
-
- assert(hiveconf.get(testKey, "") == testVal)
- assertResult(Array(s"$testKey=$testVal")) {
- sql("SET").collect().map(_.getString(0))
- }
-
- sql(s"SET ${testKey + testKey}=${testVal + testVal}")
- assert(hiveconf.get(testKey + testKey, "") == testVal + testVal)
- assertResult(Array(s"$testKey=$testVal", s"${testKey + testKey}=${testVal + testVal}")) {
- sql("SET").collect().map(_.getString(0))
- }
-
- assertResult(Array(s"$testKey=$testVal")) {
- sql(s"SET $testKey").collect().map(_.getString(0))
+ // "SET key"
+ assertResult(Set(testKey -> testVal)) {
+ collectResults(sql(s"SET $testKey"))
}
assertResult(Array(s"$nonexistentKey=<undefined>")) {