From c05baabf10dd4c808929b4ae7a6d118aba6dd665 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Tue, 7 Mar 2017 09:21:58 -0800
Subject: [SPARK-19765][SPARK-18549][SQL] UNCACHE TABLE should un-cache all cached plans that refer to this table

## What changes were proposed in this pull request?

When un-caching a table, we should not only remove the cache entry for this table but also
un-cache any other cached plans that refer to it.

This PR also includes some refactoring:

1. use `java.util.LinkedList` to store the cache entries, so that it is safer to remove
   elements while iterating
2. rename `invalidateCache` to `recacheByPlan`, which makes it more obvious what it does

## How was this patch tested?

A new regression test.

Author: Wenchen Fan

Closes #17097 from cloud-fan/cache.
---
 .../apache/spark/sql/execution/CacheManager.scala  | 118 ++++++++++++---------
 .../sql/execution/columnar/InMemoryRelation.scala  |   6 --
 .../apache/spark/sql/execution/command/ddl.scala   |   3 +-
 .../datasources/InsertIntoDataSourceCommand.scala  |   5 +-
 .../apache/spark/sql/internal/CatalogImpl.scala    |  23 ++--
 .../org/apache/spark/sql/CachedTableSuite.scala    |  50 ++++++---
 6 files changed, 114 insertions(+), 91 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 80138510dc..0ea806d6cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution
 
 import java.util.concurrent.locks.ReentrantReadWriteLock
 
+import scala.collection.JavaConverters._
+
 import org.apache.hadoop.fs.{FileSystem, Path}
 
 import org.apache.spark.internal.Logging
@@ -45,7 +47,7 @@ case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation)
 class CacheManager extends Logging {
 
   @transient
-  private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData]
+  private val cachedData = new java.util.LinkedList[CachedData]
 
   @transient
   private val cacheLock = new ReentrantReadWriteLock
@@ -70,7 +72,7 @@ class CacheManager extends Logging {
 
   /** Clears all cached tables. */
   def clearCache(): Unit = writeLock {
-    cachedData.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
+    cachedData.asScala.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
     cachedData.clear()
   }
 
@@ -88,46 +90,81 @@ class CacheManager extends Logging {
       query: Dataset[_],
       tableName: Option[String] = None,
       storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock {
-    val planToCache = query.queryExecution.analyzed
+    val planToCache = query.logicalPlan
     if (lookupCachedData(planToCache).nonEmpty) {
       logWarning("Asked to cache already cached data.")
     } else {
       val sparkSession = query.sparkSession
-      cachedData +=
-        CachedData(
-          planToCache,
-          InMemoryRelation(
-            sparkSession.sessionState.conf.useCompression,
-            sparkSession.sessionState.conf.columnBatchSize,
-            storageLevel,
-            sparkSession.sessionState.executePlan(planToCache).executedPlan,
-            tableName))
+      cachedData.add(CachedData(
+        planToCache,
+        InMemoryRelation(
+          sparkSession.sessionState.conf.useCompression,
+          sparkSession.sessionState.conf.columnBatchSize,
+          storageLevel,
+          sparkSession.sessionState.executePlan(planToCache).executedPlan,
+          tableName)))
     }
   }
 
   /**
-   * Tries to remove the data for the given [[Dataset]] from the cache.
-   * No operation, if it's already uncached.
+   * Un-cache all the cache entries that refer to the given plan.
+   */
+  def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock {
+    uncacheQuery(query.sparkSession, query.logicalPlan, blocking)
+  }
+
+  /**
+   * Un-cache all the cache entries that refer to the given plan.
    */
-  def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Boolean = writeLock {
-    val planToCache = query.queryExecution.analyzed
-    val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
-    val found = dataIndex >= 0
-    if (found) {
-      cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
-      cachedData.remove(dataIndex)
+  def uncacheQuery(spark: SparkSession, plan: LogicalPlan, blocking: Boolean): Unit = writeLock {
+    val it = cachedData.iterator()
+    while (it.hasNext) {
+      val cd = it.next()
+      if (cd.plan.find(_.sameResult(plan)).isDefined) {
+        cd.cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
+        it.remove()
+      }
     }
-    found
+  }
+
+  /**
+   * Tries to re-cache all the cache entries that refer to the given plan.
+   */
+  def recacheByPlan(spark: SparkSession, plan: LogicalPlan): Unit = writeLock {
+    recacheByCondition(spark, _.find(_.sameResult(plan)).isDefined)
+  }
+
+  private def recacheByCondition(spark: SparkSession, condition: LogicalPlan => Boolean): Unit = {
+    val it = cachedData.iterator()
+    val needToRecache = scala.collection.mutable.ArrayBuffer.empty[CachedData]
+    while (it.hasNext) {
+      val cd = it.next()
+      if (condition(cd.plan)) {
+        cd.cachedRepresentation.cachedColumnBuffers.unpersist()
+        // Remove the cache entry before we create a new one, so that we can have a different
+        // physical plan.
+        it.remove()
+        val newCache = InMemoryRelation(
+          useCompression = cd.cachedRepresentation.useCompression,
+          batchSize = cd.cachedRepresentation.batchSize,
+          storageLevel = cd.cachedRepresentation.storageLevel,
+          child = spark.sessionState.executePlan(cd.plan).executedPlan,
+          tableName = cd.cachedRepresentation.tableName)
+        needToRecache += cd.copy(cachedRepresentation = newCache)
+      }
+    }
+
+    needToRecache.foreach(cachedData.add)
   }
 
   /** Optionally returns cached data for the given [[Dataset]] */
   def lookupCachedData(query: Dataset[_]): Option[CachedData] = readLock {
-    lookupCachedData(query.queryExecution.analyzed)
+    lookupCachedData(query.logicalPlan)
   }
 
   /** Optionally returns cached data for the given [[LogicalPlan]]. */
   def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock {
-    cachedData.find(cd => plan.sameResult(cd.plan))
+    cachedData.asScala.find(cd => plan.sameResult(cd.plan))
   }
 
   /** Replaces segments of the given logical plan with cached versions where possible. */
@@ -145,40 +182,17 @@ class CacheManager extends Logging {
   }
 
   /**
-   * Invalidates the cache of any data that contains `plan`. Note that it is possible that this
-   * function will over invalidate.
-   */
-  def invalidateCache(plan: LogicalPlan): Unit = writeLock {
-    cachedData.foreach {
-      case data if data.plan.collect { case p if p.sameResult(plan) => p }.nonEmpty =>
-        data.cachedRepresentation.recache()
-      case _ =>
-    }
-  }
-
-  /**
-   * Invalidates the cache of any data that contains `resourcePath` in one or more
+   * Tries to re-cache all the cache entries that contain `resourcePath` in one or more
    * `HadoopFsRelation` node(s) as part of its logical plan.
    */
-  def invalidateCachedPath(
-      sparkSession: SparkSession, resourcePath: String): Unit = writeLock {
+  def recacheByPath(spark: SparkSession, resourcePath: String): Unit = writeLock {
     val (fs, qualifiedPath) = {
       val path = new Path(resourcePath)
-      val fs = path.getFileSystem(sparkSession.sessionState.newHadoopConf())
-      (fs, path.makeQualified(fs.getUri, fs.getWorkingDirectory))
+      val fs = path.getFileSystem(spark.sessionState.newHadoopConf())
+      (fs, fs.makeQualified(path))
     }
 
-    cachedData.filter {
-      case data if data.plan.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined => true
-      case _ => false
-    }.foreach { data =>
-      val dataIndex = cachedData.indexWhere(cd => data.plan.sameResult(cd.plan))
-      if (dataIndex >= 0) {
-        data.cachedRepresentation.cachedColumnBuffers.unpersist(blocking = true)
-        cachedData.remove(dataIndex)
-      }
-      sparkSession.sharedState.cacheManager.cacheQuery(Dataset.ofRows(sparkSession, data.plan))
-    }
+    recacheByCondition(spark, _.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 37bd95e737..36037ac003 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -85,12 +85,6 @@ case class InMemoryRelation(
     buildBuffers()
   }
 
-  def recache(): Unit = {
-    _cachedColumnBuffers.unpersist()
-    _cachedColumnBuffers = null
-    buildBuffers()
-  }
-
   private def buildBuffers(): Unit = {
     val output = child.output
     val cached = child.execute().mapPartitionsInternal { rowIterator =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index b5c6042351..9d3c55060d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -199,8 +199,7 @@ case class DropTableCommand(
       }
     }
     try {
-      sparkSession.sharedState.cacheManager.uncacheQuery(
-        sparkSession.table(tableName.quotedString))
+      sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
     } catch {
       case _: NoSuchTableException if ifExists =>
       case NonFatal(e) => log.warn(e.toString, e)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
index b2ff68a833..a813829d50 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
@@ -42,8 +42,9 @@ case class InsertIntoDataSourceCommand(
     val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
     relation.insert(df, overwrite)
 
-    // Invalidate the cache.
-    sparkSession.sharedState.cacheManager.invalidateCache(logicalRelation)
+    // Re-cache all cached plans(including this relation itself, if it's cached) that refer to this
+    // data source relation.
+    sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, logicalRelation)
 
     Seq.empty[Row]
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
index ed07ff3ff0..53374859f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
@@ -343,8 +343,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
    * @since 2.0.0
    */
   override def dropTempView(viewName: String): Boolean = {
-    sparkSession.sessionState.catalog.getTempView(viewName).exists { tempView =>
-      sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, tempView))
+    sparkSession.sessionState.catalog.getTempView(viewName).exists { viewDef =>
+      sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true)
       sessionCatalog.dropTempView(viewName)
     }
   }
@@ -359,7 +359,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
    */
   override def dropGlobalTempView(viewName: String): Boolean = {
     sparkSession.sessionState.catalog.getGlobalTempView(viewName).exists { viewDef =>
-      sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, viewDef))
+      sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true)
       sessionCatalog.dropGlobalTempView(viewName)
     }
   }
@@ -404,7 +404,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
    * @since 2.0.0
    */
   override def uncacheTable(tableName: String): Unit = {
-    sparkSession.sharedState.cacheManager.uncacheQuery(query = sparkSession.table(tableName))
+    sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
   }
 
   /**
@@ -442,17 +442,12 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
 
     // If this table is cached as an InMemoryRelation, drop the original
     // cached version and make the new version cached lazily.
-    val logicalPlan = sparkSession.table(tableIdent).queryExecution.analyzed
-    // Use lookupCachedData directly since RefreshTable also takes databaseName.
-    val isCached = sparkSession.sharedState.cacheManager.lookupCachedData(logicalPlan).nonEmpty
-    if (isCached) {
-      // Create a data frame to represent the table.
-      // TODO: Use uncacheTable once it supports database name.
-      val df = Dataset.ofRows(sparkSession, logicalPlan)
+    val table = sparkSession.table(tableIdent)
+    if (isCached(table)) {
       // Uncache the logicalPlan.
-      sparkSession.sharedState.cacheManager.uncacheQuery(df, blocking = true)
+      sparkSession.sharedState.cacheManager.uncacheQuery(table, blocking = true)
       // Cache it again.
-      sparkSession.sharedState.cacheManager.cacheQuery(df, Some(tableIdent.table))
+      sparkSession.sharedState.cacheManager.cacheQuery(table, Some(tableIdent.table))
     }
   }
 
@@ -464,7 +459,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
    * @since 2.0.0
    */
   override def refreshByPath(resourcePath: String): Unit = {
-    sparkSession.sharedState.cacheManager.invalidateCachedPath(sparkSession, resourcePath)
+    sparkSession.sharedState.cacheManager.recacheByPath(sparkSession, resourcePath)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 2a0e088437..7a7d52b214 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -24,15 +24,15 @@ import scala.language.postfixOps
 import org.scalatest.concurrent.Eventually._
 
 import org.apache.spark.CleanerListener
-import org.apache.spark.sql.catalyst.expressions.{Expression, SubqueryExpression}
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
 import org.apache.spark.sql.execution.RDDScanExec
 import org.apache.spark.sql.execution.columnar._
 import org.apache.spark.sql.execution.exchange.ShuffleExchange
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
 import org.apache.spark.storage.{RDDBlockId, StorageLevel}
-import org.apache.spark.util.AccumulatorContext
+import org.apache.spark.util.{AccumulatorContext, Utils}
 
 private case class BigData(s: String)
 
@@ -65,7 +65,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     maybeBlock.nonEmpty
   }
 
-  private def getNumInMemoryRelations(plan: LogicalPlan): Int = {
+  private def getNumInMemoryRelations(ds: Dataset[_]): Int = {
+    val plan = ds.queryExecution.withCachedData
     var sum = plan.collect { case _: InMemoryRelation => 1 }.sum
     plan.transformAllExpressions {
       case e: SubqueryExpression =>
@@ -187,7 +188,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     assertCached(spark.table("testData"))
 
     assertResult(1, "InMemoryRelation not found, testData should have been cached") {
-      getNumInMemoryRelations(spark.table("testData").queryExecution.withCachedData)
+      getNumInMemoryRelations(spark.table("testData"))
     }
 
     spark.catalog.cacheTable("testData")
@@ -580,21 +581,21 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     localRelation.createOrReplaceTempView("localRelation")
     spark.catalog.cacheTable("localRelation")
 
-    assert(getNumInMemoryRelations(localRelation.queryExecution.withCachedData) == 1)
+    assert(getNumInMemoryRelations(localRelation) == 1)
   }
 
   test("SPARK-19093 Caching in side subquery") {
     withTempView("t1") {
       Seq(1).toDF("c1").createOrReplaceTempView("t1")
       spark.catalog.cacheTable("t1")
-      val cachedPlan =
+      val ds =
         sql(
           """
             |SELECT * FROM t1
            |WHERE
            |NOT EXISTS (SELECT * FROM t1)
-          """.stripMargin).queryExecution.optimizedPlan
-      assert(getNumInMemoryRelations(cachedPlan) == 2)
+          """.stripMargin)
+      assert(getNumInMemoryRelations(ds) == 2)
     }
   }
 
@@ -610,17 +611,17 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
       spark.catalog.cacheTable("t4")
 
       // Nested predicate subquery
-      val cachedPlan =
+      val ds =
        sql(
          """
            |SELECT * FROM t1
            |WHERE
            |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1))
-          """.stripMargin).queryExecution.optimizedPlan
-      assert(getNumInMemoryRelations(cachedPlan) == 3)
+          """.stripMargin)
+      assert(getNumInMemoryRelations(ds) == 3)
 
      // Scalar subquery and predicate subquery
-      val cachedPlan2 =
+      val ds2 =
        sql(
          """
            |SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1)
            |WHERE
            |c1 = (SELECT max(c1) FROM t2 GROUP BY c1)
            |OR
            |EXISTS (SELECT c1 FROM t3)
            |OR
            |c1 IN (SELECT c1 FROM t4)
-          """.stripMargin).queryExecution.optimizedPlan
-      assert(getNumInMemoryRelations(cachedPlan2) == 4)
+          """.stripMargin)
+      assert(getNumInMemoryRelations(ds2) == 4)
+    }
+  }
+
+  test("SPARK-19765: UNCACHE TABLE should un-cache all cached plans that refer to this table") {
+    withTable("t") {
+      withTempPath { path =>
+        Seq(1 -> "a").toDF("i", "j").write.parquet(path.getCanonicalPath)
+        sql(s"CREATE TABLE t USING parquet LOCATION '$path'")
+        spark.catalog.cacheTable("t")
+        spark.table("t").select($"i").cache()
+        checkAnswer(spark.table("t").select($"i"), Row(1))
+        assertCached(spark.table("t").select($"i"))
+
+        Utils.deleteRecursively(path)
+        spark.sessionState.catalog.refreshTable(TableIdentifier("t"))
+        spark.catalog.uncacheTable("t")
+        assert(spark.table("t").select($"i").count() == 0)
+        assert(getNumInMemoryRelations(spark.table("t").select($"i")) == 0)
+      }
     }
   }
 }
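As a quick orientation for readers of this patch, the snippet below sketches the user-visible behavior the new `uncacheQuery` path is meant to guarantee. It is an illustrative example only and is not part of the commit: the local `SparkSession` setup, the temp view name `t`, and the column names are made up, and the comments describe the intended post-patch semantics using only public Catalog/Dataset APIs.

```scala
// Hypothetical walkthrough (not from the patch) of the behavior SPARK-19765 fixes.
import org.apache.spark.sql.SparkSession

object UncacheDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("uncache-demo").getOrCreate()
    import spark.implicits._

    Seq(1 -> "a", 2 -> "b").toDF("i", "j").createOrReplaceTempView("t")

    spark.catalog.cacheTable("t")                // cache entry #1: the plain scan of `t`
    val derived = spark.table("t").select($"i")  // a second plan that refers to `t`
    derived.cache()                              // cache entry #2
    derived.count()                              // materialize both InMemoryRelations

    // Before this patch, the call below dropped only entry #1; the entry built for
    // `derived` survived and could keep serving stale data after `t` changed.
    // With this patch, CacheManager.uncacheQuery walks every cache entry and removes
    // any plan containing a node that `sameResult`s the un-cached plan, so entry #2
    // is dropped as well.
    spark.catalog.uncacheTable("t")
    assert(!spark.catalog.isCached("t"))

    spark.stop()
  }
}
```

The switch from `ArrayBuffer` to `java.util.LinkedList` in `CacheManager` supports exactly this pattern: entries can be removed through `Iterator.remove()` while the cache is being traversed, instead of tracking indices into a buffer that shifts as elements are deleted.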