From bb9991dec5dd631b22a05e2e1b83b9082a845e8f Mon Sep 17 00:00:00 2001
From: Shixiong Zhu
Date: Thu, 5 May 2016 14:36:47 -0700
Subject: [SPARK-15135][SQL] Make sure SparkSession thread safe

## What changes were proposed in this pull request?

Went through SparkSession and its members and fixed non-thread-safe classes used by SparkSession

## How was this patch tested?

Existing unit tests

Author: Shixiong Zhu

Closes #12915 from zsxwing/spark-session-thread-safe.
---
 .../sql/catalyst/analysis/FunctionRegistry.scala   |  10 +-
 .../sql/catalyst/catalog/InMemoryCatalog.scala     |   2 +-
 .../sql/catalyst/catalog/SessionCatalog.scala      | 102 +++++++++++----------
 .../org/apache/spark/sql/ExperimentalMethods.scala |   4 +-
 .../scala/org/apache/spark/sql/SparkSession.scala  |   7 +-
 .../org/apache/spark/sql/hive/test/TestHive.scala  |   4 +-
 6 files changed, 73 insertions(+), 56 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 1bada2ce67..ac05dd3d0e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -28,7 +28,11 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.util.StringKeyHashMap
 
-/** A catalog for looking up user defined functions, used by an [[Analyzer]]. */
+/**
+ * A catalog for looking up user defined functions, used by an [[Analyzer]].
+ *
+ * Note: The implementation should be thread-safe to allow concurrent access.
+ */
 trait FunctionRegistry {
 
   final def registerFunction(name: String, builder: FunctionBuilder): Unit = {
@@ -62,7 +66,7 @@ trait FunctionRegistry {
 
 class SimpleFunctionRegistry extends FunctionRegistry {
 
-  private[sql] val functionBuilders =
+  protected val functionBuilders =
     StringKeyHashMap[(ExpressionInfo, FunctionBuilder)](caseSensitive = false)
 
   override def registerFunction(
@@ -97,7 +101,7 @@ class SimpleFunctionRegistry extends FunctionRegistry {
     functionBuilders.remove(name).isDefined
   }
 
-  override def clear(): Unit = {
+  override def clear(): Unit = synchronized {
     functionBuilders.clear()
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
index 1d2ca2863f..c65f461129 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
@@ -340,7 +340,7 @@ class InMemoryCatalog extends ExternalCatalog {
     catalog(db).functions(funcName)
   }
 
-  override def functionExists(db: String, funcName: String): Boolean = {
+  override def functionExists(db: String, funcName: String): Boolean = synchronized {
     requireDbExists(db)
     catalog(db).functions.contains(funcName)
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index eff420eb4c..712770784b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.catalog
 
+import javax.annotation.concurrent.GuardedBy
+
 import scala.collection.mutable
 
 import org.apache.hadoop.conf.Configuration
@@ -37,7 +39,7 @@ import org.apache.spark.sql.catalyst.util.StringUtils
  * proxy to the underlying metastore (e.g. Hive Metastore) and it also manages temporary
  * tables and functions of the Spark Session that it belongs to.
  *
- * This class is not thread-safe.
+ * This class must be thread-safe.
  */
 class SessionCatalog(
     externalCatalog: ExternalCatalog,
@@ -66,12 +68,14 @@ class SessionCatalog(
   }
 
   /** List of temporary tables, mapping from table name to their logical plan. */
+  @GuardedBy("this")
   protected val tempTables = new mutable.HashMap[String, LogicalPlan]
 
   // Note: we track current database here because certain operations do not explicitly
   // specify the database (e.g. DROP TABLE my_table). In these cases we must first
   // check whether the temporary table or function exists, then, if not, operate on
   // the corresponding item in the current database.
+  @GuardedBy("this")
   protected var currentDb = {
     val defaultName = "default"
     val defaultDbDefinition =
@@ -137,13 +141,13 @@ class SessionCatalog(
     externalCatalog.listDatabases(pattern)
   }
 
-  def getCurrentDatabase: String = currentDb
+  def getCurrentDatabase: String = synchronized { currentDb }
 
   def setCurrentDatabase(db: String): Unit = {
     if (!databaseExists(db)) {
       throw new AnalysisException(s"Database '$db' does not exist.")
     }
-    currentDb = db
+    synchronized { currentDb = db }
   }
 
   /**
@@ -173,7 +177,7 @@ class SessionCatalog(
    * If no such database is specified, create it in the current database.
    */
   def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = {
-    val db = tableDefinition.identifier.database.getOrElse(currentDb)
+    val db = tableDefinition.identifier.database.getOrElse(getCurrentDatabase)
     val table = formatTableName(tableDefinition.identifier.table)
     val newTableDefinition = tableDefinition.copy(identifier = TableIdentifier(table, Some(db)))
     externalCatalog.createTable(db, newTableDefinition, ignoreIfExists)
@@ -189,7 +193,7 @@ class SessionCatalog(
    * this becomes a no-op.
    */
   def alterTable(tableDefinition: CatalogTable): Unit = {
-    val db = tableDefinition.identifier.database.getOrElse(currentDb)
+    val db = tableDefinition.identifier.database.getOrElse(getCurrentDatabase)
     val table = formatTableName(tableDefinition.identifier.table)
     val newTableDefinition = tableDefinition.copy(identifier = TableIdentifier(table, Some(db)))
     externalCatalog.alterTable(db, newTableDefinition)
@@ -201,7 +205,7 @@ class SessionCatalog(
    * If the specified table is not found in the database then an [[AnalysisException]] is thrown.
    */
   def getTableMetadata(name: TableIdentifier): CatalogTable = {
-    val db = name.database.getOrElse(currentDb)
+    val db = name.database.getOrElse(getCurrentDatabase)
     val table = formatTableName(name.table)
     externalCatalog.getTable(db, table)
   }
@@ -212,7 +216,7 @@ class SessionCatalog(
    * If the specified table is not found in the database then return None if it doesn't exist.
    */
   def getTableMetadataOption(name: TableIdentifier): Option[CatalogTable] = {
-    val db = name.database.getOrElse(currentDb)
+    val db = name.database.getOrElse(getCurrentDatabase)
     val table = formatTableName(name.table)
     externalCatalog.getTableOption(db, table)
   }
@@ -227,7 +231,7 @@ class SessionCatalog(
       loadPath: String,
       isOverwrite: Boolean,
       holdDDLTime: Boolean): Unit = {
-    val db = name.database.getOrElse(currentDb)
+    val db = name.database.getOrElse(getCurrentDatabase)
     val table = formatTableName(name.table)
     externalCatalog.loadTable(db, table, loadPath, isOverwrite, holdDDLTime)
   }
@@ -245,14 +249,14 @@ class SessionCatalog(
       holdDDLTime: Boolean,
       inheritTableSpecs: Boolean,
       isSkewedStoreAsSubdir: Boolean): Unit = {
-    val db = name.database.getOrElse(currentDb)
+    val db = name.database.getOrElse(getCurrentDatabase)
     val table = formatTableName(name.table)
     externalCatalog.loadPartition(db, table, loadPath, partition, isOverwrite, holdDDLTime,
       inheritTableSpecs, isSkewedStoreAsSubdir)
   }
 
   def defaultTablePath(tableIdent: TableIdentifier): String = {
-    val dbName = tableIdent.database.getOrElse(currentDb)
+    val dbName = tableIdent.database.getOrElse(getCurrentDatabase)
     val dbLocation = getDatabaseMetadata(dbName).locationUri
 
     new Path(new Path(dbLocation), formatTableName(tableIdent.table)).toString
@@ -268,7 +272,7 @@ class SessionCatalog(
   def createTempTable(
       name: String,
       tableDefinition: LogicalPlan,
-      overrideIfExists: Boolean): Unit = {
+      overrideIfExists: Boolean): Unit = synchronized {
     val table = formatTableName(name)
     if (tempTables.contains(table) && !overrideIfExists) {
       throw new AnalysisException(s"Temporary table '$name' already exists.")
@@ -285,7 +289,7 @@ class SessionCatalog(
    *
    * This assumes the database specified in `oldName` matches the one specified in `newName`.
    */
-  def renameTable(oldName: TableIdentifier, newName: TableIdentifier): Unit = {
+  def renameTable(oldName: TableIdentifier, newName: TableIdentifier): Unit = synchronized {
     val db = oldName.database.getOrElse(currentDb)
     val newDb = newName.database.getOrElse(currentDb)
     if (db != newDb) {
@@ -310,7 +314,7 @@ class SessionCatalog(
    * If no database is specified, this will first attempt to drop a temporary table with
    * the same name, then, if that does not exist, drop the table from the current database.
    */
-  def dropTable(name: TableIdentifier, ignoreIfNotExists: Boolean): Unit = {
+  def dropTable(name: TableIdentifier, ignoreIfNotExists: Boolean): Unit = synchronized {
     val db = name.database.getOrElse(currentDb)
     val table = formatTableName(name.table)
     if (name.database.isDefined || !tempTables.contains(table)) {
@@ -334,19 +338,21 @@ class SessionCatalog(
    * the same name, then, if that does not exist, return the table from the current database.
    */
   def lookupRelation(name: TableIdentifier, alias: Option[String] = None): LogicalPlan = {
-    val db = name.database.getOrElse(currentDb)
-    val table = formatTableName(name.table)
-    val relation =
-      if (name.database.isDefined || !tempTables.contains(table)) {
-        val metadata = externalCatalog.getTable(db, table)
-        SimpleCatalogRelation(db, metadata, alias)
-      } else {
-        tempTables(table)
-      }
-    val qualifiedTable = SubqueryAlias(table, relation)
-    // If an alias was specified by the lookup, wrap the plan in a subquery so that
-    // attributes are properly qualified with this alias.
-    alias.map(a => SubqueryAlias(a, qualifiedTable)).getOrElse(qualifiedTable)
+    synchronized {
+      val db = name.database.getOrElse(currentDb)
+      val table = formatTableName(name.table)
+      val relation =
+        if (name.database.isDefined || !tempTables.contains(table)) {
+          val metadata = externalCatalog.getTable(db, table)
+          SimpleCatalogRelation(db, metadata, alias)
+        } else {
+          tempTables(table)
+        }
+      val qualifiedTable = SubqueryAlias(table, relation)
+      // If an alias was specified by the lookup, wrap the plan in a subquery so that
+      // attributes are properly qualified with this alias.
+      alias.map(a => SubqueryAlias(a, qualifiedTable)).getOrElse(qualifiedTable)
+    }
   }
 
   /**
@@ -357,7 +363,7 @@ class SessionCatalog(
    * table with the same name, we will return false if the specified database does not
    * contain the table.
    */
-  def tableExists(name: TableIdentifier): Boolean = {
+  def tableExists(name: TableIdentifier): Boolean = synchronized {
     val db = name.database.getOrElse(currentDb)
     val table = formatTableName(name.table)
     if (name.database.isDefined || !tempTables.contains(table)) {
@@ -373,7 +379,7 @@ class SessionCatalog(
    * Note: The temporary table cache is checked only when database is not
    * explicitly specified.
    */
-  def isTemporaryTable(name: TableIdentifier): Boolean = {
+  def isTemporaryTable(name: TableIdentifier): Boolean = synchronized {
     name.database.isEmpty && tempTables.contains(formatTableName(name.table))
   }
 
@@ -388,9 +394,11 @@ class SessionCatalog(
   def listTables(db: String, pattern: String): Seq[TableIdentifier] = {
     val dbTables =
       externalCatalog.listTables(db, pattern).map { t => TableIdentifier(t, Some(db)) }
-    val _tempTables = StringUtils.filterPattern(tempTables.keys.toSeq, pattern)
-      .map { t => TableIdentifier(t) }
-    dbTables ++ _tempTables
+    synchronized {
+      val _tempTables = StringUtils.filterPattern(tempTables.keys.toSeq, pattern)
+        .map { t => TableIdentifier(t) }
+      dbTables ++ _tempTables
+    }
   }
 
   // TODO: It's strange that we have both refresh and invalidate here.
@@ -409,7 +417,7 @@ class SessionCatalog(
    * Drop all existing temporary tables.
    * For testing only.
    */
-  def clearTempTables(): Unit = {
+  def clearTempTables(): Unit = synchronized {
     tempTables.clear()
   }
 
@@ -417,7 +425,7 @@ class SessionCatalog(
    * Return a temporary table exactly as it was stored.
    * For testing only.
    */
-  private[catalog] def getTempTable(name: String): Option[LogicalPlan] = {
+  private[catalog] def getTempTable(name: String): Option[LogicalPlan] = synchronized {
     tempTables.get(name)
   }
@@ -441,7 +449,7 @@ class SessionCatalog(
       tableName: TableIdentifier,
       parts: Seq[CatalogTablePartition],
       ignoreIfExists: Boolean): Unit = {
-    val db = tableName.database.getOrElse(currentDb)
+    val db = tableName.database.getOrElse(getCurrentDatabase)
     val table = formatTableName(tableName.table)
     externalCatalog.createPartitions(db, table, parts, ignoreIfExists)
   }
@@ -454,7 +462,7 @@ class SessionCatalog(
       tableName: TableIdentifier,
       parts: Seq[TablePartitionSpec],
       ignoreIfNotExists: Boolean): Unit = {
-    val db = tableName.database.getOrElse(currentDb)
+    val db = tableName.database.getOrElse(getCurrentDatabase)
     val table = formatTableName(tableName.table)
     externalCatalog.dropPartitions(db, table, parts, ignoreIfNotExists)
   }
@@ -469,7 +477,7 @@ class SessionCatalog(
       tableName: TableIdentifier,
       specs: Seq[TablePartitionSpec],
       newSpecs: Seq[TablePartitionSpec]): Unit = {
-    val db = tableName.database.getOrElse(currentDb)
+    val db = tableName.database.getOrElse(getCurrentDatabase)
     val table = formatTableName(tableName.table)
     externalCatalog.renamePartitions(db, table, specs, newSpecs)
   }
@@ -484,7 +492,7 @@ class SessionCatalog(
    * this becomes a no-op.
    */
   def alterPartitions(tableName: TableIdentifier, parts: Seq[CatalogTablePartition]): Unit = {
-    val db = tableName.database.getOrElse(currentDb)
+    val db = tableName.database.getOrElse(getCurrentDatabase)
     val table = formatTableName(tableName.table)
     externalCatalog.alterPartitions(db, table, parts)
   }
@@ -494,7 +502,7 @@ class SessionCatalog(
    * If no database is specified, assume the table is in the current database.
    */
   def getPartition(tableName: TableIdentifier, spec: TablePartitionSpec): CatalogTablePartition = {
-    val db = tableName.database.getOrElse(currentDb)
+    val db = tableName.database.getOrElse(getCurrentDatabase)
     val table = formatTableName(tableName.table)
     externalCatalog.getPartition(db, table, spec)
   }
@@ -509,7 +517,7 @@ class SessionCatalog(
   def listPartitions(
       tableName: TableIdentifier,
       partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = {
-    val db = tableName.database.getOrElse(currentDb)
+    val db = tableName.database.getOrElse(getCurrentDatabase)
     val table = formatTableName(tableName.table)
     externalCatalog.listPartitions(db, table, partialSpec)
   }
@@ -532,7 +540,7 @@ class SessionCatalog(
    * If no such database is specified, create it in the current database.
    */
   def createFunction(funcDefinition: CatalogFunction, ignoreIfExists: Boolean): Unit = {
-    val db = funcDefinition.identifier.database.getOrElse(currentDb)
+    val db = funcDefinition.identifier.database.getOrElse(getCurrentDatabase)
     val identifier = FunctionIdentifier(funcDefinition.identifier.funcName, Some(db))
     val newFuncDefinition = funcDefinition.copy(identifier = identifier)
     if (!functionExists(identifier)) {
@@ -547,7 +555,7 @@ class SessionCatalog(
    * If no database is specified, assume the function is in the current database.
    */
   def dropFunction(name: FunctionIdentifier, ignoreIfNotExists: Boolean): Unit = {
-    val db = name.database.getOrElse(currentDb)
+    val db = name.database.getOrElse(getCurrentDatabase)
     val identifier = name.copy(database = Some(db))
     if (functionExists(identifier)) {
       // TODO: registry should just take in FunctionIdentifier for type safety
@@ -571,7 +579,7 @@ class SessionCatalog(
    * If no database is specified, this will return the function in the current database.
    */
   def getFunctionMetadata(name: FunctionIdentifier): CatalogFunction = {
-    val db = name.database.getOrElse(currentDb)
+    val db = name.database.getOrElse(getCurrentDatabase)
     externalCatalog.getFunction(db, name.funcName)
   }
 
@@ -579,7 +587,7 @@ class SessionCatalog(
    * Check if the specified function exists.
    */
   def functionExists(name: FunctionIdentifier): Boolean = {
-    val db = name.database.getOrElse(currentDb)
+    val db = name.database.getOrElse(getCurrentDatabase)
     functionRegistry.functionExists(name.unquotedString) ||
       externalCatalog.functionExists(db, name.funcName)
   }
@@ -644,7 +652,7 @@ class SessionCatalog(
   /**
    * Look up the [[ExpressionInfo]] associated with the specified function, assuming it exists.
    */
-  private[spark] def lookupFunctionInfo(name: FunctionIdentifier): ExpressionInfo = {
+  private[spark] def lookupFunctionInfo(name: FunctionIdentifier): ExpressionInfo = synchronized {
     // TODO: just make function registry take in FunctionIdentifier instead of duplicating this
     val qualifiedName = name.copy(database = name.database.orElse(Some(currentDb)))
     functionRegistry.lookupFunction(name.funcName)
@@ -673,7 +681,9 @@ class SessionCatalog(
    * based on the function class and put the builder into the FunctionRegistry.
    * The name of this function in the FunctionRegistry will be `databaseName.functionName`.
    */
-  def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = {
+  def lookupFunction(
+      name: FunctionIdentifier,
+      children: Seq[Expression]): Expression = synchronized {
     // Note: the implementation of this function is a little bit convoluted.
     // We probably shouldn't use a single FunctionRegistry to register all three kinds of functions
     // (built-in, temp, and external).
@@ -741,7 +751,7 @@ class SessionCatalog(
    *
    * This is mainly used for tests.
    */
-  private[sql] def reset(): Unit = {
+  private[sql] def reset(): Unit = synchronized {
     val default = "default"
     listDatabases().filter(_ != default).foreach { db =>
       dropDatabase(db, ignoreIfNotExists = false, cascade = true)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
index c5df028485..a49da6dc2b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
@@ -42,9 +42,9 @@ class ExperimentalMethods private[sql]() {
    * @since 1.3.0
    */
   @Experimental
-  var extraStrategies: Seq[Strategy] = Nil
+  @volatile var extraStrategies: Seq[Strategy] = Nil
 
   @Experimental
-  var extraOptimizations: Seq[Rule[LogicalPlan]] = Nil
+  @volatile var extraOptimizations: Seq[Rule[LogicalPlan]] = Nil
 
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 9ed3756628..2a893c6478 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -115,14 +115,17 @@ class SparkSession private(
   @transient
   private var _wrapped: SQLContext = _
 
-  protected[sql] def wrapped: SQLContext = {
+  @transient
+  private val _wrappedLock = new Object
+
+  protected[sql] def wrapped: SQLContext = _wrappedLock.synchronized {
     if (_wrapped == null) {
       _wrapped = new SQLContext(self, isRootContext = false)
     }
     _wrapped
   }
 
-  protected[sql] def setWrappedContext(sqlContext: SQLContext): Unit = {
+  protected[sql] def setWrappedContext(sqlContext: SQLContext): Unit = _wrappedLock.synchronized {
     _wrapped = sqlContext
   }
 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index 42746ece3c..6d418c1dcf 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -485,11 +485,11 @@ private[hive] class TestHiveFunctionRegistry extends SimpleFunctionRegistry {
   private val removedFunctions =
     collection.mutable.ArrayBuffer.empty[(String, (ExpressionInfo, FunctionBuilder))]
 
-  def unregisterFunction(name: String): Unit = {
+  def unregisterFunction(name: String): Unit = synchronized {
     functionBuilders.remove(name).foreach(f => removedFunctions += name -> f)
   }
 
-  def restore(): Unit = {
+  def restore(): Unit = synchronized {
     removedFunctions.foreach {
       case (name, (info, builder)) => registerFunction(name, info, builder)
     }
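The changes above repeat a handful of thread-safety idioms: guard the session's mutable maps and `currentDb` with the object's own monitor (documented with `@GuardedBy`), keep check-then-act sequences inside a single `synchronized` block, mark independently replaced fields `@volatile`, and lazily initialize a shared reference under a dedicated lock. The following is a minimal, self-contained sketch of those idioms; the class and member names are invented for illustration and are not Spark APIs, and it only assumes the JSR-305 `javax.annotation.concurrent` annotations are on the classpath, as they are for Spark itself.

```scala
import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable

// Illustrative only: this class does not exist in Spark. It condenses the locking
// patterns applied to SessionCatalog, ExperimentalMethods and SparkSession above.
class ThreadSafeSessionState {

  // Shared mutable state: every read and write goes through this object's monitor,
  // just as SessionCatalog does for tempTables and currentDb.
  @GuardedBy("this")
  private val tempTables = mutable.HashMap.empty[String, String]

  @GuardedBy("this")
  private var currentDb: String = "default"

  def getCurrentDatabase: String = synchronized { currentDb }

  def setCurrentDatabase(db: String): Unit = synchronized { currentDb = db }

  // Check-then-act must stay inside one synchronized block so another thread cannot
  // interleave between the contains() check and the update.
  def createTempTable(name: String, plan: String, overrideIfExists: Boolean): Unit = synchronized {
    if (tempTables.contains(name) && !overrideIfExists) {
      throw new IllegalStateException(s"Temporary table '$name' already exists.")
    }
    tempTables.put(name, plan)
  }

  // A field that is only read or replaced wholesale (like extraStrategies) needs
  // visibility rather than mutual exclusion, so @volatile is enough.
  @volatile var extraStrategies: Seq[String] = Nil

  // Lazy initialization of a shared reference (like SparkSession.wrapped) happens
  // under a dedicated lock object so it does not contend with the monitor above.
  private val wrappedLock = new Object

  @GuardedBy("wrappedLock")
  private var wrapped: Option[String] = None

  def getOrCreateWrapped(): String = wrappedLock.synchronized {
    if (wrapped.isEmpty) {
      wrapped = Some("initialized")
    }
    wrapped.get
  }
}
```

Using a separate lock object for the lazy-initialization path mirrors the `_wrappedLock` introduced in SparkSession: threads that create the wrapped context do not have to contend with callers that only touch the monitor-guarded state.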