path: root/sql/core/src/main
author    Wenchen Fan <wenchen@databricks.com>  2017-03-07 09:21:58 -0800
committer Xiao Li <gatorsmile@gmail.com>  2017-03-07 09:21:58 -0800
commit    c05baabf10dd4c808929b4ae7a6d118aba6dd665 (patch)
tree      82bcafd601ad4d90279bf82ebbe7b6c9bec7b4cc /sql/core/src/main
parent    030acdd1f06f49383079c306b63e874ad738851f (diff)
[SPARK-19765][SPARK-18549][SQL] UNCACHE TABLE should un-cache all cached plans that refer to this table
## What changes were proposed in this pull request?

When un-caching a table, we should not only remove the cache entry for the table itself, but also un-cache any other cached plans that refer to that table.

This PR also includes some refactoring:

1. Use `java.util.LinkedList` to store the cache entries, so that it is safer to remove elements while iterating.
2. Rename `invalidateCache` to `recacheByPlan`, which makes it more obvious what it does.

## How was this patch tested?

New regression test.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #17097 from cloud-fan/cache.
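The intended user-visible effect can be illustrated with a minimal sketch (not part of the patch): cache a table and a query derived from it, issue `UNCACHE TABLE`, and both cache entries should be dropped. The table name `t`, the sample data, and the `UncacheCascadeExample` object are illustrative assumptions, not code from this change.

```scala
// Minimal sketch, assuming a local SparkSession; names here are illustrative only.
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel

object UncacheCascadeExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("uncache-cascade").getOrCreate()
    import spark.implicits._

    // A small temp view standing in for a real table.
    Seq((1, "a"), (2, "b")).toDF("id", "name").createOrReplaceTempView("t")

    spark.table("t").cache()                                // cache entry for the table itself
    val derived = spark.table("t").groupBy("name").count()
    derived.cache()                                         // cache entry for a plan that refers to `t`

    spark.sql("UNCACHE TABLE t")

    // Previously only the entry for `t` was dropped; with this change the derived
    // plan's entry is removed as well, because its plan refers to `t`.
    assert(derived.storageLevel == StorageLevel.NONE)

    spark.stop()
  }
}
```

Internally, `uncacheQuery` now iterates over every cache entry and drops any whose plan contains a sub-plan that returns the same result as the un-cached plan, instead of removing only an exact match.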
Diffstat (limited to 'sql/core/src/main')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala                             | 118
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala                |   6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala                              |   3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala  |   5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala                               |  23
5 files changed, 79 insertions, 76 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 80138510dc..0ea806d6cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution
import java.util.concurrent.locks.ReentrantReadWriteLock
+import scala.collection.JavaConverters._
+
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.internal.Logging
@@ -45,7 +47,7 @@ case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation)
class CacheManager extends Logging {
@transient
- private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData]
+ private val cachedData = new java.util.LinkedList[CachedData]
@transient
private val cacheLock = new ReentrantReadWriteLock
@@ -70,7 +72,7 @@ class CacheManager extends Logging {
/** Clears all cached tables. */
def clearCache(): Unit = writeLock {
- cachedData.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
+ cachedData.asScala.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
cachedData.clear()
}
@@ -88,46 +90,81 @@ class CacheManager extends Logging {
query: Dataset[_],
tableName: Option[String] = None,
storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock {
- val planToCache = query.queryExecution.analyzed
+ val planToCache = query.logicalPlan
if (lookupCachedData(planToCache).nonEmpty) {
logWarning("Asked to cache already cached data.")
} else {
val sparkSession = query.sparkSession
- cachedData +=
- CachedData(
- planToCache,
- InMemoryRelation(
- sparkSession.sessionState.conf.useCompression,
- sparkSession.sessionState.conf.columnBatchSize,
- storageLevel,
- sparkSession.sessionState.executePlan(planToCache).executedPlan,
- tableName))
+ cachedData.add(CachedData(
+ planToCache,
+ InMemoryRelation(
+ sparkSession.sessionState.conf.useCompression,
+ sparkSession.sessionState.conf.columnBatchSize,
+ storageLevel,
+ sparkSession.sessionState.executePlan(planToCache).executedPlan,
+ tableName)))
}
}
/**
- * Tries to remove the data for the given [[Dataset]] from the cache.
- * No operation, if it's already uncached.
+ * Un-cache all the cache entries that refer to the given plan.
+ */
+ def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock {
+ uncacheQuery(query.sparkSession, query.logicalPlan, blocking)
+ }
+
+ /**
+ * Un-cache all the cache entries that refer to the given plan.
*/
- def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Boolean = writeLock {
- val planToCache = query.queryExecution.analyzed
- val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
- val found = dataIndex >= 0
- if (found) {
- cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
- cachedData.remove(dataIndex)
+ def uncacheQuery(spark: SparkSession, plan: LogicalPlan, blocking: Boolean): Unit = writeLock {
+ val it = cachedData.iterator()
+ while (it.hasNext) {
+ val cd = it.next()
+ if (cd.plan.find(_.sameResult(plan)).isDefined) {
+ cd.cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
+ it.remove()
+ }
}
- found
+ }
+
+ /**
+ * Tries to re-cache all the cache entries that refer to the given plan.
+ */
+ def recacheByPlan(spark: SparkSession, plan: LogicalPlan): Unit = writeLock {
+ recacheByCondition(spark, _.find(_.sameResult(plan)).isDefined)
+ }
+
+ private def recacheByCondition(spark: SparkSession, condition: LogicalPlan => Boolean): Unit = {
+ val it = cachedData.iterator()
+ val needToRecache = scala.collection.mutable.ArrayBuffer.empty[CachedData]
+ while (it.hasNext) {
+ val cd = it.next()
+ if (condition(cd.plan)) {
+ cd.cachedRepresentation.cachedColumnBuffers.unpersist()
+ // Remove the cache entry before we create a new one, so that we can have a different
+ // physical plan.
+ it.remove()
+ val newCache = InMemoryRelation(
+ useCompression = cd.cachedRepresentation.useCompression,
+ batchSize = cd.cachedRepresentation.batchSize,
+ storageLevel = cd.cachedRepresentation.storageLevel,
+ child = spark.sessionState.executePlan(cd.plan).executedPlan,
+ tableName = cd.cachedRepresentation.tableName)
+ needToRecache += cd.copy(cachedRepresentation = newCache)
+ }
+ }
+
+ needToRecache.foreach(cachedData.add)
}
/** Optionally returns cached data for the given [[Dataset]] */
def lookupCachedData(query: Dataset[_]): Option[CachedData] = readLock {
- lookupCachedData(query.queryExecution.analyzed)
+ lookupCachedData(query.logicalPlan)
}
/** Optionally returns cached data for the given [[LogicalPlan]]. */
def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock {
- cachedData.find(cd => plan.sameResult(cd.plan))
+ cachedData.asScala.find(cd => plan.sameResult(cd.plan))
}
/** Replaces segments of the given logical plan with cached versions where possible. */
@@ -145,40 +182,17 @@ class CacheManager extends Logging {
}
/**
- * Invalidates the cache of any data that contains `plan`. Note that it is possible that this
- * function will over invalidate.
- */
- def invalidateCache(plan: LogicalPlan): Unit = writeLock {
- cachedData.foreach {
- case data if data.plan.collect { case p if p.sameResult(plan) => p }.nonEmpty =>
- data.cachedRepresentation.recache()
- case _ =>
- }
- }
-
- /**
- * Invalidates the cache of any data that contains `resourcePath` in one or more
+ * Tries to re-cache all the cache entries that contain `resourcePath` in one or more
* `HadoopFsRelation` node(s) as part of its logical plan.
*/
- def invalidateCachedPath(
- sparkSession: SparkSession, resourcePath: String): Unit = writeLock {
+ def recacheByPath(spark: SparkSession, resourcePath: String): Unit = writeLock {
val (fs, qualifiedPath) = {
val path = new Path(resourcePath)
- val fs = path.getFileSystem(sparkSession.sessionState.newHadoopConf())
- (fs, path.makeQualified(fs.getUri, fs.getWorkingDirectory))
+ val fs = path.getFileSystem(spark.sessionState.newHadoopConf())
+ (fs, fs.makeQualified(path))
}
- cachedData.filter {
- case data if data.plan.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined => true
- case _ => false
- }.foreach { data =>
- val dataIndex = cachedData.indexWhere(cd => data.plan.sameResult(cd.plan))
- if (dataIndex >= 0) {
- data.cachedRepresentation.cachedColumnBuffers.unpersist(blocking = true)
- cachedData.remove(dataIndex)
- }
- sparkSession.sharedState.cacheManager.cacheQuery(Dataset.ofRows(sparkSession, data.plan))
- }
+ recacheByCondition(spark, _.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 37bd95e737..36037ac003 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -85,12 +85,6 @@ case class InMemoryRelation(
buildBuffers()
}
- def recache(): Unit = {
- _cachedColumnBuffers.unpersist()
- _cachedColumnBuffers = null
- buildBuffers()
- }
-
private def buildBuffers(): Unit = {
val output = child.output
val cached = child.execute().mapPartitionsInternal { rowIterator =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index b5c6042351..9d3c55060d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -199,8 +199,7 @@ case class DropTableCommand(
}
}
try {
- sparkSession.sharedState.cacheManager.uncacheQuery(
- sparkSession.table(tableName.quotedString))
+ sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
} catch {
case _: NoSuchTableException if ifExists =>
case NonFatal(e) => log.warn(e.toString, e)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
index b2ff68a833..a813829d50 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
@@ -42,8 +42,9 @@ case class InsertIntoDataSourceCommand(
val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
relation.insert(df, overwrite)
- // Invalidate the cache.
- sparkSession.sharedState.cacheManager.invalidateCache(logicalRelation)
+ // Re-cache all cached plans(including this relation itself, if it's cached) that refer to this
+ // data source relation.
+ sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, logicalRelation)
Seq.empty[Row]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
index ed07ff3ff0..53374859f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
@@ -343,8 +343,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
* @since 2.0.0
*/
override def dropTempView(viewName: String): Boolean = {
- sparkSession.sessionState.catalog.getTempView(viewName).exists { tempView =>
- sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, tempView))
+ sparkSession.sessionState.catalog.getTempView(viewName).exists { viewDef =>
+ sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true)
sessionCatalog.dropTempView(viewName)
}
}
@@ -359,7 +359,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
*/
override def dropGlobalTempView(viewName: String): Boolean = {
sparkSession.sessionState.catalog.getGlobalTempView(viewName).exists { viewDef =>
- sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, viewDef))
+ sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true)
sessionCatalog.dropGlobalTempView(viewName)
}
}
@@ -404,7 +404,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
* @since 2.0.0
*/
override def uncacheTable(tableName: String): Unit = {
- sparkSession.sharedState.cacheManager.uncacheQuery(query = sparkSession.table(tableName))
+ sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
}
/**
@@ -442,17 +442,12 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
// If this table is cached as an InMemoryRelation, drop the original
// cached version and make the new version cached lazily.
- val logicalPlan = sparkSession.table(tableIdent).queryExecution.analyzed
- // Use lookupCachedData directly since RefreshTable also takes databaseName.
- val isCached = sparkSession.sharedState.cacheManager.lookupCachedData(logicalPlan).nonEmpty
- if (isCached) {
- // Create a data frame to represent the table.
- // TODO: Use uncacheTable once it supports database name.
- val df = Dataset.ofRows(sparkSession, logicalPlan)
+ val table = sparkSession.table(tableIdent)
+ if (isCached(table)) {
// Uncache the logicalPlan.
- sparkSession.sharedState.cacheManager.uncacheQuery(df, blocking = true)
+ sparkSession.sharedState.cacheManager.uncacheQuery(table, blocking = true)
// Cache it again.
- sparkSession.sharedState.cacheManager.cacheQuery(df, Some(tableIdent.table))
+ sparkSession.sharedState.cacheManager.cacheQuery(table, Some(tableIdent.table))
}
}
@@ -464,7 +459,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
* @since 2.0.0
*/
override def refreshByPath(resourcePath: String): Unit = {
- sparkSession.sharedState.cacheManager.invalidateCachedPath(sparkSession, resourcePath)
+ sparkSession.sharedState.cacheManager.recacheByPath(sparkSession, resourcePath)
}
}