author    Shixiong Zhu <shixiong@databricks.com>  2016-05-05 14:36:47 -0700
committer Andrew Or <andrew@databricks.com>  2016-05-05 14:36:47 -0700
commit    bb9991dec5dd631b22a05e2e1b83b9082a845e8f (patch)
tree      860c9a34ffc11fff35620e32d80061ef9128d3b2
parent    ed6f3f8a5f3a6bf7c53e13c2798de398c9a526a6 (diff)
[SPARK-15135][SQL] Make sure SparkSession thread safe
## What changes were proposed in this pull request?

Went through SparkSession and its members and fixed the non-thread-safe classes used by SparkSession.

## How was this patch tested?

Existing unit tests.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #12915 from zsxwing/spark-session-thread-safe.
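The pattern is easiest to see in isolation. The sketch below is a hypothetical, simplified class (`SimpleSessionState` and its members are invented for illustration, not Spark APIs): state that is mutated in place sits behind `synchronized` accessors, while fields that are only swapped out wholesale are marked `@volatile`, mirroring what the patch does in SessionCatalog and ExperimentalMethods.

```scala
import scala.collection.mutable

// Hypothetical, simplified session state (not a Spark class) illustrating the
// two techniques used in this patch:
//  1. state mutated in place (maps, the current database) is guarded by `this`
//     and only touched inside synchronized blocks -- the real patch also
//     documents this with @GuardedBy("this") from javax.annotation.concurrent;
//  2. fields that are only replaced wholesale with new immutable values are
//     marked @volatile, so readers always see the latest reference.
class SimpleSessionState {
  private val tempTables = new mutable.HashMap[String, String]
  private var currentDb: String = "default"

  def getCurrentDatabase: String = synchronized { currentDb }

  def setCurrentDatabase(db: String): Unit = synchronized { currentDb = db }

  def registerTempTable(name: String, plan: String): Unit = synchronized {
    tempTables.put(name, plan)
  }

  def lookupTempTable(name: String): Option[String] = synchronized {
    tempTables.get(name)
  }

  // Replaced as a whole, never mutated in place, so a lock is not needed.
  @volatile var extraStrategies: Seq[String] = Nil
}

object SimpleSessionStateDemo extends App {
  val state = new SimpleSessionState
  state.setCurrentDatabase("analytics")
  state.registerTempTable("events", "LocalRelation")
  println(state.getCurrentDatabase)          // analytics
  println(state.lookupTempTable("events"))   // Some(LocalRelation)
}
```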
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala |  10
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala   |   2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala    | 102
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala                    |   4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala                           |   7
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala                     |   4
6 files changed, 73 insertions(+), 56 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 1bada2ce67..ac05dd3d0e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -28,7 +28,11 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.util.StringKeyHashMap
-/** A catalog for looking up user defined functions, used by an [[Analyzer]]. */
+/**
+ * A catalog for looking up user defined functions, used by an [[Analyzer]].
+ *
+ * Note: The implementation should be thread-safe to allow concurrent access.
+ */
trait FunctionRegistry {
final def registerFunction(name: String, builder: FunctionBuilder): Unit = {
@@ -62,7 +66,7 @@ trait FunctionRegistry {
class SimpleFunctionRegistry extends FunctionRegistry {
- private[sql] val functionBuilders =
+ protected val functionBuilders =
StringKeyHashMap[(ExpressionInfo, FunctionBuilder)](caseSensitive = false)
override def registerFunction(
@@ -97,7 +101,7 @@ class SimpleFunctionRegistry extends FunctionRegistry {
functionBuilders.remove(name).isDefined
}
- override def clear(): Unit = {
+ override def clear(): Unit = synchronized {
functionBuilders.clear()
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
index 1d2ca2863f..c65f461129 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
@@ -340,7 +340,7 @@ class InMemoryCatalog extends ExternalCatalog {
catalog(db).functions(funcName)
}
- override def functionExists(db: String, funcName: String): Boolean = {
+ override def functionExists(db: String, funcName: String): Boolean = synchronized {
requireDbExists(db)
catalog(db).functions.contains(funcName)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index eff420eb4c..712770784b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.catalog
+import javax.annotation.concurrent.GuardedBy
+
import scala.collection.mutable
import org.apache.hadoop.conf.Configuration
@@ -37,7 +39,7 @@ import org.apache.spark.sql.catalyst.util.StringUtils
* proxy to the underlying metastore (e.g. Hive Metastore) and it also manages temporary
* tables and functions of the Spark Session that it belongs to.
*
- * This class is not thread-safe.
+ * This class must be thread-safe.
*/
class SessionCatalog(
externalCatalog: ExternalCatalog,
@@ -66,12 +68,14 @@ class SessionCatalog(
}
/** List of temporary tables, mapping from table name to their logical plan. */
+ @GuardedBy("this")
protected val tempTables = new mutable.HashMap[String, LogicalPlan]
// Note: we track current database here because certain operations do not explicitly
// specify the database (e.g. DROP TABLE my_table). In these cases we must first
// check whether the temporary table or function exists, then, if not, operate on
// the corresponding item in the current database.
+ @GuardedBy("this")
protected var currentDb = {
val defaultName = "default"
val defaultDbDefinition =
@@ -137,13 +141,13 @@ class SessionCatalog(
externalCatalog.listDatabases(pattern)
}
- def getCurrentDatabase: String = currentDb
+ def getCurrentDatabase: String = synchronized { currentDb }
def setCurrentDatabase(db: String): Unit = {
if (!databaseExists(db)) {
throw new AnalysisException(s"Database '$db' does not exist.")
}
- currentDb = db
+ synchronized { currentDb = db }
}
/**
@@ -173,7 +177,7 @@ class SessionCatalog(
* If no such database is specified, create it in the current database.
*/
def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = {
- val db = tableDefinition.identifier.database.getOrElse(currentDb)
+ val db = tableDefinition.identifier.database.getOrElse(getCurrentDatabase)
val table = formatTableName(tableDefinition.identifier.table)
val newTableDefinition = tableDefinition.copy(identifier = TableIdentifier(table, Some(db)))
externalCatalog.createTable(db, newTableDefinition, ignoreIfExists)
@@ -189,7 +193,7 @@ class SessionCatalog(
* this becomes a no-op.
*/
def alterTable(tableDefinition: CatalogTable): Unit = {
- val db = tableDefinition.identifier.database.getOrElse(currentDb)
+ val db = tableDefinition.identifier.database.getOrElse(getCurrentDatabase)
val table = formatTableName(tableDefinition.identifier.table)
val newTableDefinition = tableDefinition.copy(identifier = TableIdentifier(table, Some(db)))
externalCatalog.alterTable(db, newTableDefinition)
@@ -201,7 +205,7 @@ class SessionCatalog(
* If the specified table is not found in the database then an [[AnalysisException]] is thrown.
*/
def getTableMetadata(name: TableIdentifier): CatalogTable = {
- val db = name.database.getOrElse(currentDb)
+ val db = name.database.getOrElse(getCurrentDatabase)
val table = formatTableName(name.table)
externalCatalog.getTable(db, table)
}
@@ -212,7 +216,7 @@ class SessionCatalog(
* If the specified table is not found in the database then return None if it doesn't exist.
*/
def getTableMetadataOption(name: TableIdentifier): Option[CatalogTable] = {
- val db = name.database.getOrElse(currentDb)
+ val db = name.database.getOrElse(getCurrentDatabase)
val table = formatTableName(name.table)
externalCatalog.getTableOption(db, table)
}
@@ -227,7 +231,7 @@ class SessionCatalog(
loadPath: String,
isOverwrite: Boolean,
holdDDLTime: Boolean): Unit = {
- val db = name.database.getOrElse(currentDb)
+ val db = name.database.getOrElse(getCurrentDatabase)
val table = formatTableName(name.table)
externalCatalog.loadTable(db, table, loadPath, isOverwrite, holdDDLTime)
}
@@ -245,14 +249,14 @@ class SessionCatalog(
holdDDLTime: Boolean,
inheritTableSpecs: Boolean,
isSkewedStoreAsSubdir: Boolean): Unit = {
- val db = name.database.getOrElse(currentDb)
+ val db = name.database.getOrElse(getCurrentDatabase)
val table = formatTableName(name.table)
externalCatalog.loadPartition(db, table, loadPath, partition, isOverwrite, holdDDLTime,
inheritTableSpecs, isSkewedStoreAsSubdir)
}
def defaultTablePath(tableIdent: TableIdentifier): String = {
- val dbName = tableIdent.database.getOrElse(currentDb)
+ val dbName = tableIdent.database.getOrElse(getCurrentDatabase)
val dbLocation = getDatabaseMetadata(dbName).locationUri
new Path(new Path(dbLocation), formatTableName(tableIdent.table)).toString
@@ -268,7 +272,7 @@ class SessionCatalog(
def createTempTable(
name: String,
tableDefinition: LogicalPlan,
- overrideIfExists: Boolean): Unit = {
+ overrideIfExists: Boolean): Unit = synchronized {
val table = formatTableName(name)
if (tempTables.contains(table) && !overrideIfExists) {
throw new AnalysisException(s"Temporary table '$name' already exists.")
@@ -285,7 +289,7 @@ class SessionCatalog(
*
* This assumes the database specified in `oldName` matches the one specified in `newName`.
*/
- def renameTable(oldName: TableIdentifier, newName: TableIdentifier): Unit = {
+ def renameTable(oldName: TableIdentifier, newName: TableIdentifier): Unit = synchronized {
val db = oldName.database.getOrElse(currentDb)
val newDb = newName.database.getOrElse(currentDb)
if (db != newDb) {
@@ -310,7 +314,7 @@ class SessionCatalog(
* If no database is specified, this will first attempt to drop a temporary table with
* the same name, then, if that does not exist, drop the table from the current database.
*/
- def dropTable(name: TableIdentifier, ignoreIfNotExists: Boolean): Unit = {
+ def dropTable(name: TableIdentifier, ignoreIfNotExists: Boolean): Unit = synchronized {
val db = name.database.getOrElse(currentDb)
val table = formatTableName(name.table)
if (name.database.isDefined || !tempTables.contains(table)) {
@@ -334,19 +338,21 @@ class SessionCatalog(
* the same name, then, if that does not exist, return the table from the current database.
*/
def lookupRelation(name: TableIdentifier, alias: Option[String] = None): LogicalPlan = {
- val db = name.database.getOrElse(currentDb)
- val table = formatTableName(name.table)
- val relation =
- if (name.database.isDefined || !tempTables.contains(table)) {
- val metadata = externalCatalog.getTable(db, table)
- SimpleCatalogRelation(db, metadata, alias)
- } else {
- tempTables(table)
- }
- val qualifiedTable = SubqueryAlias(table, relation)
- // If an alias was specified by the lookup, wrap the plan in a subquery so that
- // attributes are properly qualified with this alias.
- alias.map(a => SubqueryAlias(a, qualifiedTable)).getOrElse(qualifiedTable)
+ synchronized {
+ val db = name.database.getOrElse(currentDb)
+ val table = formatTableName(name.table)
+ val relation =
+ if (name.database.isDefined || !tempTables.contains(table)) {
+ val metadata = externalCatalog.getTable(db, table)
+ SimpleCatalogRelation(db, metadata, alias)
+ } else {
+ tempTables(table)
+ }
+ val qualifiedTable = SubqueryAlias(table, relation)
+ // If an alias was specified by the lookup, wrap the plan in a subquery so that
+ // attributes are properly qualified with this alias.
+ alias.map(a => SubqueryAlias(a, qualifiedTable)).getOrElse(qualifiedTable)
+ }
}
/**
@@ -357,7 +363,7 @@ class SessionCatalog(
* table with the same name, we will return false if the specified database does not
* contain the table.
*/
- def tableExists(name: TableIdentifier): Boolean = {
+ def tableExists(name: TableIdentifier): Boolean = synchronized {
val db = name.database.getOrElse(currentDb)
val table = formatTableName(name.table)
if (name.database.isDefined || !tempTables.contains(table)) {
@@ -373,7 +379,7 @@ class SessionCatalog(
* Note: The temporary table cache is checked only when database is not
* explicitly specified.
*/
- def isTemporaryTable(name: TableIdentifier): Boolean = {
+ def isTemporaryTable(name: TableIdentifier): Boolean = synchronized {
name.database.isEmpty && tempTables.contains(formatTableName(name.table))
}
@@ -388,9 +394,11 @@ class SessionCatalog(
def listTables(db: String, pattern: String): Seq[TableIdentifier] = {
val dbTables =
externalCatalog.listTables(db, pattern).map { t => TableIdentifier(t, Some(db)) }
- val _tempTables = StringUtils.filterPattern(tempTables.keys.toSeq, pattern)
- .map { t => TableIdentifier(t) }
- dbTables ++ _tempTables
+ synchronized {
+ val _tempTables = StringUtils.filterPattern(tempTables.keys.toSeq, pattern)
+ .map { t => TableIdentifier(t) }
+ dbTables ++ _tempTables
+ }
}
// TODO: It's strange that we have both refresh and invalidate here.
@@ -409,7 +417,7 @@ class SessionCatalog(
* Drop all existing temporary tables.
* For testing only.
*/
- def clearTempTables(): Unit = {
+ def clearTempTables(): Unit = synchronized {
tempTables.clear()
}
@@ -417,7 +425,7 @@ class SessionCatalog(
* Return a temporary table exactly as it was stored.
* For testing only.
*/
- private[catalog] def getTempTable(name: String): Option[LogicalPlan] = {
+ private[catalog] def getTempTable(name: String): Option[LogicalPlan] = synchronized {
tempTables.get(name)
}
@@ -441,7 +449,7 @@ class SessionCatalog(
tableName: TableIdentifier,
parts: Seq[CatalogTablePartition],
ignoreIfExists: Boolean): Unit = {
- val db = tableName.database.getOrElse(currentDb)
+ val db = tableName.database.getOrElse(getCurrentDatabase)
val table = formatTableName(tableName.table)
externalCatalog.createPartitions(db, table, parts, ignoreIfExists)
}
@@ -454,7 +462,7 @@ class SessionCatalog(
tableName: TableIdentifier,
parts: Seq[TablePartitionSpec],
ignoreIfNotExists: Boolean): Unit = {
- val db = tableName.database.getOrElse(currentDb)
+ val db = tableName.database.getOrElse(getCurrentDatabase)
val table = formatTableName(tableName.table)
externalCatalog.dropPartitions(db, table, parts, ignoreIfNotExists)
}
@@ -469,7 +477,7 @@ class SessionCatalog(
tableName: TableIdentifier,
specs: Seq[TablePartitionSpec],
newSpecs: Seq[TablePartitionSpec]): Unit = {
- val db = tableName.database.getOrElse(currentDb)
+ val db = tableName.database.getOrElse(getCurrentDatabase)
val table = formatTableName(tableName.table)
externalCatalog.renamePartitions(db, table, specs, newSpecs)
}
@@ -484,7 +492,7 @@ class SessionCatalog(
* this becomes a no-op.
*/
def alterPartitions(tableName: TableIdentifier, parts: Seq[CatalogTablePartition]): Unit = {
- val db = tableName.database.getOrElse(currentDb)
+ val db = tableName.database.getOrElse(getCurrentDatabase)
val table = formatTableName(tableName.table)
externalCatalog.alterPartitions(db, table, parts)
}
@@ -494,7 +502,7 @@ class SessionCatalog(
* If no database is specified, assume the table is in the current database.
*/
def getPartition(tableName: TableIdentifier, spec: TablePartitionSpec): CatalogTablePartition = {
- val db = tableName.database.getOrElse(currentDb)
+ val db = tableName.database.getOrElse(getCurrentDatabase)
val table = formatTableName(tableName.table)
externalCatalog.getPartition(db, table, spec)
}
@@ -509,7 +517,7 @@ class SessionCatalog(
def listPartitions(
tableName: TableIdentifier,
partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = {
- val db = tableName.database.getOrElse(currentDb)
+ val db = tableName.database.getOrElse(getCurrentDatabase)
val table = formatTableName(tableName.table)
externalCatalog.listPartitions(db, table, partialSpec)
}
@@ -532,7 +540,7 @@ class SessionCatalog(
* If no such database is specified, create it in the current database.
*/
def createFunction(funcDefinition: CatalogFunction, ignoreIfExists: Boolean): Unit = {
- val db = funcDefinition.identifier.database.getOrElse(currentDb)
+ val db = funcDefinition.identifier.database.getOrElse(getCurrentDatabase)
val identifier = FunctionIdentifier(funcDefinition.identifier.funcName, Some(db))
val newFuncDefinition = funcDefinition.copy(identifier = identifier)
if (!functionExists(identifier)) {
@@ -547,7 +555,7 @@ class SessionCatalog(
* If no database is specified, assume the function is in the current database.
*/
def dropFunction(name: FunctionIdentifier, ignoreIfNotExists: Boolean): Unit = {
- val db = name.database.getOrElse(currentDb)
+ val db = name.database.getOrElse(getCurrentDatabase)
val identifier = name.copy(database = Some(db))
if (functionExists(identifier)) {
// TODO: registry should just take in FunctionIdentifier for type safety
@@ -571,7 +579,7 @@ class SessionCatalog(
* If no database is specified, this will return the function in the current database.
*/
def getFunctionMetadata(name: FunctionIdentifier): CatalogFunction = {
- val db = name.database.getOrElse(currentDb)
+ val db = name.database.getOrElse(getCurrentDatabase)
externalCatalog.getFunction(db, name.funcName)
}
@@ -579,7 +587,7 @@ class SessionCatalog(
* Check if the specified function exists.
*/
def functionExists(name: FunctionIdentifier): Boolean = {
- val db = name.database.getOrElse(currentDb)
+ val db = name.database.getOrElse(getCurrentDatabase)
functionRegistry.functionExists(name.unquotedString) ||
externalCatalog.functionExists(db, name.funcName)
}
@@ -644,7 +652,7 @@ class SessionCatalog(
/**
* Look up the [[ExpressionInfo]] associated with the specified function, assuming it exists.
*/
- private[spark] def lookupFunctionInfo(name: FunctionIdentifier): ExpressionInfo = {
+ private[spark] def lookupFunctionInfo(name: FunctionIdentifier): ExpressionInfo = synchronized {
// TODO: just make function registry take in FunctionIdentifier instead of duplicating this
val qualifiedName = name.copy(database = name.database.orElse(Some(currentDb)))
functionRegistry.lookupFunction(name.funcName)
@@ -673,7 +681,9 @@ class SessionCatalog(
* based on the function class and put the builder into the FunctionRegistry.
* The name of this function in the FunctionRegistry will be `databaseName.functionName`.
*/
- def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = {
+ def lookupFunction(
+ name: FunctionIdentifier,
+ children: Seq[Expression]): Expression = synchronized {
// Note: the implementation of this function is a little bit convoluted.
// We probably shouldn't use a single FunctionRegistry to register all three kinds of functions
// (built-in, temp, and external).
@@ -741,7 +751,7 @@ class SessionCatalog(
*
* This is mainly used for tests.
*/
- private[sql] def reset(): Unit = {
+ private[sql] def reset(): Unit = synchronized {
val default = "default"
listDatabases().filter(_ != default).foreach { db =>
dropDatabase(db, ignoreIfNotExists = false, cascade = true)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
index c5df028485..a49da6dc2b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
@@ -42,9 +42,9 @@ class ExperimentalMethods private[sql]() {
* @since 1.3.0
*/
@Experimental
- var extraStrategies: Seq[Strategy] = Nil
+ @volatile var extraStrategies: Seq[Strategy] = Nil
@Experimental
- var extraOptimizations: Seq[Rule[LogicalPlan]] = Nil
+ @volatile var extraOptimizations: Seq[Rule[LogicalPlan]] = Nil
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 9ed3756628..2a893c6478 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -115,14 +115,17 @@ class SparkSession private(
@transient
private var _wrapped: SQLContext = _
- protected[sql] def wrapped: SQLContext = {
+ @transient
+ private val _wrappedLock = new Object
+
+ protected[sql] def wrapped: SQLContext = _wrappedLock.synchronized {
if (_wrapped == null) {
_wrapped = new SQLContext(self, isRootContext = false)
}
_wrapped
}
- protected[sql] def setWrappedContext(sqlContext: SQLContext): Unit = {
+ protected[sql] def setWrappedContext(sqlContext: SQLContext): Unit = _wrappedLock.synchronized {
_wrapped = sqlContext
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index 42746ece3c..6d418c1dcf 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -485,11 +485,11 @@ private[hive] class TestHiveFunctionRegistry extends SimpleFunctionRegistry {
private val removedFunctions =
collection.mutable.ArrayBuffer.empty[(String, (ExpressionInfo, FunctionBuilder))]
- def unregisterFunction(name: String): Unit = {
+ def unregisterFunction(name: String): Unit = synchronized {
functionBuilders.remove(name).foreach(f => removedFunctions += name -> f)
}
- def restore(): Unit = {
+ def restore(): Unit = synchronized {
removedFunctions.foreach {
case (name, (info, builder)) => registerFunction(name, info, builder)
}