author     Kunal Khamar <kkhamar@outlook.com>        2017-03-08 13:06:22 -0800
committer  Shixiong Zhu <shixiong@databricks.com>    2017-03-08 13:20:45 -0800
commit     6570cfd7abe349dc6d2151f2ac9dc662e7465a79 (patch)
tree       97b54a89a3d228c737203989d6b68db5ec75d8ef /sql/core/src/test/scala/org
parent     1bf9012380de2aa7bdf39220b55748defde8b700 (diff)
[SPARK-19540][SQL] Add ability to clone SparkSession wherein cloned session has an identical copy of the SessionState
Forking a newSession() from a SparkSession currently creates a new SparkSession that does not retain SessionState (i.e. temporary tables, SQL config, registered functions, etc.). This change adds a cloneSession() method that creates a new SparkSession with a copy of the parent's SessionState. Subsequent changes to the base session are not propagated to the cloned session; the clone is independent after creation. If the base session is changed after the clone has been created, say the user registers a new UDF, then the new UDF will not be available inside the clone. The same applies to configs and temp tables.

Tested with unit tests.

Author: Kunal Khamar <kkhamar@outlook.com>
Author: Shixiong Zhu <shixiong@databricks.com>

Closes #16826 from kunalkhamar/fork-sparksession.
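A minimal sketch of the behavior described above, mirroring the new tests (the config key and UDF name are illustrative, and cloneSession() is called here the way the test suite calls it; this diff does not establish whether it is part of the public API):

    import org.apache.spark.sql.SparkSession

    val base = SparkSession.builder().master("local").getOrCreate()
    base.conf.set("spark.test.clone.key", "active")   // illustrative key, not a real Spark conf

    // The clone copies SessionState (configs, temp views, registered functions)
    // but shares the parent's SharedState (see the last test in SessionStateSuite).
    val forked = base.cloneSession()
    assert(forked.conf.get("spark.test.clone.key") == "active")

    // After creation the two sessions are independent in both directions.
    forked.conf.set("spark.test.clone.key", "forked") // not visible in base
    base.udf.register("addone", (i: Int) => i + 1)    // not visible in forked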
Diffstat (limited to 'sql/core/src/test/scala/org')
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala         | 162
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala      |  21
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala |  18
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala        |  20
4 files changed, 209 insertions(+), 12 deletions(-)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala
new file mode 100644
index 0000000000..2d5e37242a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+
+class SessionStateSuite extends SparkFunSuite
+ with BeforeAndAfterEach with BeforeAndAfterAll {
+
+ /**
+ * A shared SparkSession for all tests in this suite. Make sure you reset any changes to this
+ * session: in HiveSessionStateSuite it is a singleton HiveSparkSession, shared with all Hive
+ * test suites.
+ */
+ protected var activeSession: SparkSession = _
+
+ override def beforeAll(): Unit = {
+ activeSession = SparkSession.builder().master("local").getOrCreate()
+ }
+
+ override def afterAll(): Unit = {
+ if (activeSession != null) {
+ activeSession.stop()
+ activeSession = null
+ }
+ super.afterAll()
+ }
+
+ test("fork new session and inherit RuntimeConfig options") {
+ val key = "spark-config-clone"
+ try {
+ activeSession.conf.set(key, "active")
+
+ // inheritance
+ val forkedSession = activeSession.cloneSession()
+ assert(forkedSession ne activeSession)
+ assert(forkedSession.conf ne activeSession.conf)
+ assert(forkedSession.conf.get(key) == "active")
+
+ // independence
+ forkedSession.conf.set(key, "forked")
+ assert(activeSession.conf.get(key) == "active")
+ activeSession.conf.set(key, "dontcopyme")
+ assert(forkedSession.conf.get(key) == "forked")
+ } finally {
+ activeSession.conf.unset(key)
+ }
+ }
+
+ test("fork new session and inherit function registry and udf") {
+ val testFuncName1 = "strlenScala"
+ val testFuncName2 = "addone"
+ try {
+ activeSession.udf.register(testFuncName1, (_: String).length + (_: Int))
+ val forkedSession = activeSession.cloneSession()
+
+ // inheritance
+ assert(forkedSession ne activeSession)
+ assert(forkedSession.sessionState.functionRegistry ne
+ activeSession.sessionState.functionRegistry)
+ assert(forkedSession.sessionState.functionRegistry.lookupFunction(testFuncName1).nonEmpty)
+
+ // independence
+ forkedSession.sessionState.functionRegistry.dropFunction(testFuncName1)
+ assert(activeSession.sessionState.functionRegistry.lookupFunction(testFuncName1).nonEmpty)
+ activeSession.udf.register(testFuncName2, (_: Int) + 1)
+ assert(forkedSession.sessionState.functionRegistry.lookupFunction(testFuncName2).isEmpty)
+ } finally {
+ activeSession.sessionState.functionRegistry.dropFunction(testFuncName1)
+ activeSession.sessionState.functionRegistry.dropFunction(testFuncName2)
+ }
+ }
+
+ test("fork new session and inherit experimental methods") {
+ val originalExtraOptimizations = activeSession.experimental.extraOptimizations
+ val originalExtraStrategies = activeSession.experimental.extraStrategies
+ try {
+ object DummyRule1 extends Rule[LogicalPlan] {
+ def apply(p: LogicalPlan): LogicalPlan = p
+ }
+ object DummyRule2 extends Rule[LogicalPlan] {
+ def apply(p: LogicalPlan): LogicalPlan = p
+ }
+ val optimizations = List(DummyRule1, DummyRule2)
+ activeSession.experimental.extraOptimizations = optimizations
+ val forkedSession = activeSession.cloneSession()
+
+ // inheritance
+ assert(forkedSession ne activeSession)
+ assert(forkedSession.experimental ne activeSession.experimental)
+ assert(forkedSession.experimental.extraOptimizations.toSet ==
+ activeSession.experimental.extraOptimizations.toSet)
+
+ // independence
+ forkedSession.experimental.extraOptimizations = List(DummyRule2)
+ assert(activeSession.experimental.extraOptimizations == optimizations)
+ activeSession.experimental.extraOptimizations = List(DummyRule1)
+ assert(forkedSession.experimental.extraOptimizations == List(DummyRule2))
+ } finally {
+ activeSession.experimental.extraOptimizations = originalExtraOptimizations
+ activeSession.experimental.extraStrategies = originalExtraStrategies
+ }
+ }
+
+ test("fork new sessions and run query on inherited table") {
+ def checkTableExists(sparkSession: SparkSession): Unit = {
+ QueryTest.checkAnswer(sparkSession.sql(
+ """
+ |SELECT x.str, COUNT(*)
+ |FROM df x JOIN df y ON x.str = y.str
+ |GROUP BY x.str
+ """.stripMargin),
+ Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil)
+ }
+
+ val spark = activeSession
+ // Cannot use `import activeSession.implicits._` due to a compiler limitation (imports require a stable identifier).
+ import spark.implicits._
+
+ try {
+ activeSession
+ .createDataset[(Int, String)](Seq(1, 2, 3).map(i => (i, i.toString)))
+ .toDF("int", "str")
+ .createOrReplaceTempView("df")
+ checkTableExists(activeSession)
+
+ val forkedSession = activeSession.cloneSession()
+ assert(forkedSession ne activeSession)
+ assert(forkedSession.sessionState ne activeSession.sessionState)
+ checkTableExists(forkedSession)
+ checkTableExists(activeSession.cloneSession()) // ability to clone multiple times
+ checkTableExists(forkedSession.cloneSession()) // clone of clone
+ } finally {
+ activeSession.sql("drop table df")
+ }
+ }
+
+ test("fork new session and inherit reference to SharedState") {
+ val forkedSession = activeSession.cloneSession()
+ assert(activeSession.sharedState eq forkedSession.sharedState)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
index 989a7f2698..fcb8ffbc6e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
@@ -493,6 +493,25 @@ class CatalogSuite
}
}
- // TODO: add tests for the rest of them
+ test("clone Catalog") {
+ // verify that temp tables are cloned
+ assert(spark.catalog.listTables().collect().isEmpty)
+ createTempTable("my_temp_table")
+ assert(spark.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table"))
+
+ // inheritance
+ val forkedSession = spark.cloneSession()
+ assert(spark ne forkedSession)
+ assert(spark.catalog ne forkedSession.catalog)
+ assert(forkedSession.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table"))
+
+ // independence
+ dropTable("my_temp_table") // drop table in original session
+ assert(spark.catalog.listTables().collect().map(_.name).toSet == Set())
+ assert(forkedSession.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table"))
+ forkedSession.sessionState.catalog
+ .createTempView("fork_table", Range(1, 2, 3, 4), overrideIfExists = true)
+ assert(spark.catalog.listTables().collect().map(_.name).toSet == Set())
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala
index 0e3a5ca9d7..f2456c7704 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala
@@ -187,4 +187,22 @@ class SQLConfEntrySuite extends SparkFunSuite {
}
assert(e2.getMessage === "The maximum size of the cache must not be negative")
}
+
+ test("clone SQLConf") {
+ val original = new SQLConf
+ val key = "spark.sql.SQLConfEntrySuite.clone"
+ assert(original.getConfString(key, "noentry") === "noentry")
+
+ // inheritance
+ original.setConfString(key, "orig")
+ val clone = original.clone()
+ assert(original ne clone)
+ assert(clone.getConfString(key, "noentry") === "orig")
+
+ // independence
+ clone.setConfString(key, "clone")
+ assert(original.getConfString(key, "noentry") === "orig")
+ original.setConfString(key, "dontcopyme")
+ assert(clone.getConfString(key, "noentry") === "clone")
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
index 8ab6db175d..898a2fb4f3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -35,18 +35,16 @@ private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) {
}
@transient
- override lazy val sessionState: SessionState = new SessionState(self) {
- override lazy val conf: SQLConf = {
- new SQLConf {
- clear()
- override def clear(): Unit = {
- super.clear()
- // Make sure we start with the default test configs even after clear
- TestSQLContext.overrideConfs.foreach { case (key, value) => setConfString(key, value) }
- }
+ override lazy val sessionState: SessionState = SessionState(
+ this,
+ new SQLConf {
+ clear()
+ override def clear(): Unit = {
+ super.clear()
+ // Make sure we start with the default test configs even after clear
+ TestSQLContext.overrideConfs.foreach { case (key, value) => setConfString(key, value) }
}
- }
- }
+ })
// Needed for Java tests
def loadTestData(): Unit = {