author     Andrew Or <andrew@databricks.com>            2015-08-03 10:58:37 -0700
committer  Tathagata Das <tathagata.das1565@gmail.com>  2015-08-03 10:58:37 -0700
commit     b41a32718d615b304efba146bf97be0229779b01 (patch)
tree       657d1474da2a14485b6106cef8089af775f86dbb
parent     69f5a7c934ac553ed52c00679b800bcffe83c1d6 (diff)
[SPARK-1855] Local checkpointing
Certain use cases of Spark involve RDDs with long lineages that must be truncated periodically (e.g. GraphX). The existing way of doing it is through `rdd.checkpoint()`, which is expensive because it writes to HDFS. This patch provides an alternative to truncate lineages cheaply *without providing the same level of fault tolerance*.

**Local checkpointing** writes checkpointed data to the local file system through the block manager. It is much faster than replicating to reliable storage and provides the same semantics as long as executors do not fail. It is accessible through a new operator, `rdd.localCheckpoint()`, and leaves the old one unchanged. Users may even decide to combine the two and call the reliable one less frequently.

The bulk of this patch involves refactoring the checkpointing interface to accept custom implementations of checkpointing. [Design doc](https://issues.apache.org/jira/secure/attachment/12741708/SPARK-7292-design.pdf).

Author: Andrew Or <andrew@databricks.com>

Closes #7279 from andrewor14/local-checkpoint and squashes the following commits:

729600f [Andrew Or] Oops, fix tests
34bc059 [Andrew Or] Avoid computing all partitions in local checkpoint
e43bbb6 [Andrew Or] Merge branch 'master' of github.com:apache/spark into local-checkpoint
3be5aea [Andrew Or] Address comments
bf846a6 [Andrew Or] Merge branch 'master' of github.com:apache/spark into local-checkpoint
ab003a3 [Andrew Or] Fix compile
c2e111b [Andrew Or] Address comments
33f167a [Andrew Or] Merge branch 'master' of github.com:apache/spark into local-checkpoint
e908a42 [Andrew Or] Fix tests
f5be0f3 [Andrew Or] Use MEMORY_AND_DISK as the default local checkpoint level
a92657d [Andrew Or] Update a few comments
e58e3e3 [Andrew Or] Merge branch 'master' of github.com:apache/spark into local-checkpoint
4eb6eb1 [Andrew Or] Merge branch 'master' of github.com:apache/spark into local-checkpoint
1bbe154 [Andrew Or] Simplify LocalCheckpointRDD
48a9996 [Andrew Or] Avoid traversing dependency tree + rewrite tests
62aba3f [Andrew Or] Merge branch 'master' of github.com:apache/spark into local-checkpoint
db70dc2 [Andrew Or] Express local checkpointing through caching the original RDD
87d43c6 [Andrew Or] Merge branch 'master' of github.com:apache/spark into local-checkpoint
c449b38 [Andrew Or] Fix style
4a182f3 [Andrew Or] Add fine-grained tests for local checkpointing
53b363b [Andrew Or] Rename a few more awkwardly named methods (minor)
e4cf071 [Andrew Or] Simplify LocalCheckpointRDD + docs + clean ups
4880deb [Andrew Or] Fix style
d096c67 [Andrew Or] Fix mima
172cb66 [Andrew Or] Fix mima?
e53d964 [Andrew Or] Fix style
56831c5 [Andrew Or] Add a few warnings and clear exception messages
2e59646 [Andrew Or] Add local checkpoint clean up tests
4dbbab1 [Andrew Or] Refactor CheckpointSuite to test local checkpointing
4514dc9 [Andrew Or] Clean local checkpoint files through RDD cleanups
0477eec [Andrew Or] Rename a few methods with awkward names (minor)
2e902e5 [Andrew Or] First implementation of local checkpointing
8447454 [Andrew Or] Fix tests
4ac1896 [Andrew Or] Refactor checkpoint interface for modularity
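For orientation, here is a minimal usage sketch (not part of the patch) contrasting the two operators; the application name, master URL, and checkpoint directory are placeholder values.

```scala
import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(
  new SparkConf().setAppName("checkpoint-demo").setMaster("local[2]"))  // placeholder settings
val data = sc.parallelize(1 to 1000000)

// Reliable checkpointing: requires a checkpoint directory and writes to stable storage.
sc.setCheckpointDir("/tmp/spark-checkpoints")  // placeholder; typically an HDFS path
val reliable = data.map(_ * 2)
reliable.checkpoint()
reliable.count()                               // first action materializes the checkpoint

// Local checkpointing: writes through the block manager; faster, but lost if an executor fails.
val local = data.map(_ * 2)
local.localCheckpoint()                        // no checkpoint directory required
local.count()                                  // caches the partitions and truncates the lineage
```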
-rw-r--r--  core/src/main/scala/org/apache/spark/ContextCleaner.scala                 |   9
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala                   |   2
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContext.scala                    |   8
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala              | 153
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala         |  67
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala     |  83
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala                         | 128
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala          | 106
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala      | 172
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala  | 108
-rw-r--r--  core/src/test/scala/org/apache/spark/CheckpointSuite.scala                | 164
-rw-r--r--  core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala            |  61
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala       | 330
-rw-r--r--  project/MimaExcludes.scala                                                 |   9
14 files changed, 1085 insertions(+), 315 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
index 37198d887b..d23c1533db 100644
--- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -22,7 +22,7 @@ import java.lang.ref.{ReferenceQueue, WeakReference}
import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.rdd.{RDDCheckpointData, RDD}
+import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData}
import org.apache.spark.util.Utils
/**
@@ -231,11 +231,14 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
}
}
- /** Perform checkpoint cleanup. */
+ /**
+ * Clean up checkpoint files written to a reliable storage.
+ * Locally checkpointed files are cleaned up separately through RDD cleanups.
+ */
def doCleanCheckpoint(rddId: Int): Unit = {
try {
logDebug("Cleaning rdd checkpoint data " + rddId)
- RDDCheckpointData.clearRDDCheckpointData(sc, rddId)
+ ReliableRDDCheckpointData.cleanCheckpoint(sc, rddId)
listeners.foreach(_.checkpointCleaned(rddId))
logInfo("Cleaned rdd checkpoint data " + rddId)
}
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 6f336a7c29..4380cf45cc 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1192,7 +1192,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}
protected[spark] def checkpointFile[T: ClassTag](path: String): RDD[T] = withScope {
- new CheckpointRDD[T](this, path)
+ new ReliableCheckpointRDD[T](this, path)
}
/** Build the union of a list of RDDs. */
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index b48836d5c8..5d2c551d58 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -59,6 +59,14 @@ object TaskContext {
* Unset the thread local TaskContext. Internal to Spark.
*/
protected[spark] def unset(): Unit = taskContext.remove()
+
+ /**
+ * Return an empty task context that is not actually used.
+ * Internal use only.
+ */
+ private[spark] def empty(): TaskContext = {
+ new TaskContextImpl(0, 0, 0, 0, null, null)
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
index e17bd47905..72fe215dae 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -17,156 +17,31 @@
package org.apache.spark.rdd
-import java.io.IOException
-
import scala.reflect.ClassTag
-import org.apache.hadoop.fs.Path
-
-import org.apache.spark._
-import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.util.{SerializableConfiguration, Utils}
+import org.apache.spark.{Partition, SparkContext, TaskContext}
+/**
+ * An RDD partition used to recover checkpointed data.
+ */
private[spark] class CheckpointRDDPartition(val index: Int) extends Partition
/**
- * This RDD represents a RDD checkpoint file (similar to HadoopRDD).
+ * An RDD that recovers checkpointed data from storage.
*/
-private[spark]
-class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String)
+private[spark] abstract class CheckpointRDD[T: ClassTag](@transient sc: SparkContext)
extends RDD[T](sc, Nil) {
- private val broadcastedConf = sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration))
-
- @transient private val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)
-
- override def getCheckpointFile: Option[String] = Some(checkpointPath)
-
- override def getPartitions: Array[Partition] = {
- val cpath = new Path(checkpointPath)
- val numPartitions =
- // listStatus can throw exception if path does not exist.
- if (fs.exists(cpath)) {
- val dirContents = fs.listStatus(cpath).map(_.getPath)
- val partitionFiles = dirContents.filter(_.getName.startsWith("part-")).map(_.toString).sorted
- val numPart = partitionFiles.length
- if (numPart > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) ||
- ! partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) {
- throw new SparkException("Invalid checkpoint directory: " + checkpointPath)
- }
- numPart
- } else 0
-
- Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i))
- }
-
- override def getPreferredLocations(split: Partition): Seq[String] = {
- val status = fs.getFileStatus(new Path(checkpointPath,
- CheckpointRDD.splitIdToFile(split.index)))
- val locations = fs.getFileBlockLocations(status, 0, status.getLen)
- locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
- }
-
- override def compute(split: Partition, context: TaskContext): Iterator[T] = {
- val file = new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index))
- CheckpointRDD.readFromFile(file, broadcastedConf, context)
- }
-
// CheckpointRDD should not be checkpointed again
- override def checkpoint(): Unit = { }
override def doCheckpoint(): Unit = { }
-}
-
-private[spark] object CheckpointRDD extends Logging {
- def splitIdToFile(splitId: Int): String = {
- "part-%05d".format(splitId)
- }
-
- def writeToFile[T: ClassTag](
- path: String,
- broadcastedConf: Broadcast[SerializableConfiguration],
- blockSize: Int = -1
- )(ctx: TaskContext, iterator: Iterator[T]) {
- val env = SparkEnv.get
- val outputDir = new Path(path)
- val fs = outputDir.getFileSystem(broadcastedConf.value.value)
-
- val finalOutputName = splitIdToFile(ctx.partitionId)
- val finalOutputPath = new Path(outputDir, finalOutputName)
- val tempOutputPath =
- new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptNumber)
-
- if (fs.exists(tempOutputPath)) {
- throw new IOException("Checkpoint failed: temporary path " +
- tempOutputPath + " already exists")
- }
- val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
-
- val fileOutputStream = if (blockSize < 0) {
- fs.create(tempOutputPath, false, bufferSize)
- } else {
- // This is mainly for testing purpose
- fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize)
- }
- val serializer = env.serializer.newInstance()
- val serializeStream = serializer.serializeStream(fileOutputStream)
- Utils.tryWithSafeFinally {
- serializeStream.writeAll(iterator)
- } {
- serializeStream.close()
- }
-
- if (!fs.rename(tempOutputPath, finalOutputPath)) {
- if (!fs.exists(finalOutputPath)) {
- logInfo("Deleting tempOutputPath " + tempOutputPath)
- fs.delete(tempOutputPath, false)
- throw new IOException("Checkpoint failed: failed to save output of task: "
- + ctx.attemptNumber + " and final output path does not exist")
- } else {
- // Some other copy of this task must've finished before us and renamed it
- logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it")
- fs.delete(tempOutputPath, false)
- }
- }
- }
-
- def readFromFile[T](
- path: Path,
- broadcastedConf: Broadcast[SerializableConfiguration],
- context: TaskContext
- ): Iterator[T] = {
- val env = SparkEnv.get
- val fs = path.getFileSystem(broadcastedConf.value.value)
- val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
- val fileInputStream = fs.open(path, bufferSize)
- val serializer = env.serializer.newInstance()
- val deserializeStream = serializer.deserializeStream(fileInputStream)
-
- // Register an on-task-completion callback to close the input stream.
- context.addTaskCompletionListener(context => deserializeStream.close())
-
- deserializeStream.asIterator.asInstanceOf[Iterator[T]]
- }
+ override def checkpoint(): Unit = { }
+ override def localCheckpoint(): this.type = this
- // Test whether CheckpointRDD generate expected number of partitions despite
- // each split file having multiple blocks. This needs to be run on a
- // cluster (mesos or standalone) using HDFS.
- def main(args: Array[String]) {
- import org.apache.spark._
+ // Note: There is a bug in MiMa that complains about `AbstractMethodProblem`s in the
+ // base [[org.apache.spark.rdd.RDD]] class if we do not override the following methods.
+ // scalastyle:off
+ protected override def getPartitions: Array[Partition] = ???
+ override def compute(p: Partition, tc: TaskContext): Iterator[T] = ???
+ // scalastyle:on
- val Array(cluster, hdfsPath) = args
- val env = SparkEnv.get
- val sc = new SparkContext(cluster, "CheckpointRDD Test")
- val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000)
- val path = new Path(hdfsPath, "temp")
- val conf = SparkHadoopUtil.get.newConfiguration(new SparkConf())
- val fs = path.getFileSystem(conf)
- val broadcastedConf = sc.broadcast(new SerializableConfiguration(conf))
- sc.runJob(rdd, CheckpointRDD.writeToFile[Int](path.toString, broadcastedConf, 1024) _)
- val cpRDD = new CheckpointRDD[Int](sc, path.toString)
- assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same")
- assert(cpRDD.collect.toList == rdd.collect.toList, "Data of partitions not the same")
- fs.delete(path, true)
- }
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala
new file mode 100644
index 0000000000..daa5779d68
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.{Partition, SparkContext, SparkEnv, SparkException, TaskContext}
+import org.apache.spark.storage.RDDBlockId
+
+/**
+ * A dummy CheckpointRDD that exists to provide informative error messages during failures.
+ *
+ * This is simply a placeholder because the original checkpointed RDD is expected to be
+ * fully cached. Only if an executor fails or if the user explicitly unpersists the original
+ * RDD will Spark ever attempt to compute this CheckpointRDD. When this happens, however,
+ * we must provide an informative error message.
+ *
+ * @param sc the active SparkContext
+ * @param rddId the ID of the checkpointed RDD
+ * @param numPartitions the number of partitions in the checkpointed RDD
+ */
+private[spark] class LocalCheckpointRDD[T: ClassTag](
+ @transient sc: SparkContext,
+ rddId: Int,
+ numPartitions: Int)
+ extends CheckpointRDD[T](sc) {
+
+ def this(rdd: RDD[T]) {
+ this(rdd.context, rdd.id, rdd.partitions.size)
+ }
+
+ protected override def getPartitions: Array[Partition] = {
+ (0 until numPartitions).toArray.map { i => new CheckpointRDDPartition(i) }
+ }
+
+ /**
+ * Throw an exception indicating that the relevant block is not found.
+ *
+ * This should only be called if the original RDD is explicitly unpersisted or if an
+ * executor is lost. Under normal circumstances, however, the original RDD (our child)
+ * is expected to be fully cached and so all partitions should already be computed and
+ * available in the block storage.
+ */
+ override def compute(partition: Partition, context: TaskContext): Iterator[T] = {
+ throw new SparkException(
+ s"Checkpoint block ${RDDBlockId(rddId, partition.index)} not found! Either the executor " +
+ s"that originally checkpointed this partition is no longer alive, or the original RDD is " +
+ s"unpersisted. If this problem persists, you may consider using `rdd.checkpoint()` " +
+ s"instead, which is slower than local checkpointing but more fault-tolerant.")
+ }
+
+}
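To make the failure mode described in the scaladoc above concrete, here is a small hypothetical sketch (assuming an existing SparkContext `sc`) of how this placeholder RDD ends up being computed.

```scala
val rdd = sc.parallelize(1 to 100).map(_ + 1)
rdd.localCheckpoint()
rdd.count()                     // materializes the local checkpoint in the executors' block managers
rdd.unpersist(blocking = true)  // drops the cached blocks that back the checkpoint
// rdd.count()                  // would now reach LocalCheckpointRDD.compute and throw
//                              // "Checkpoint block rdd_<id>_<partition> not found! ..."
```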
diff --git a/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala
new file mode 100644
index 0000000000..d6fad89684
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.{Logging, SparkEnv, SparkException, TaskContext}
+import org.apache.spark.storage.{RDDBlockId, StorageLevel}
+import org.apache.spark.util.Utils
+
+/**
+ * An implementation of checkpointing implemented on top of Spark's caching layer.
+ *
+ * Local checkpointing trades off fault tolerance for performance by skipping the expensive
+ * step of saving the RDD data to a reliable and fault-tolerant storage. Instead, the data
+ * is written to the local, ephemeral block storage that lives in each executor. This is useful
+ * for use cases where RDDs build up long lineages that need to be truncated often (e.g. GraphX).
+ */
+private[spark] class LocalRDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
+ extends RDDCheckpointData[T](rdd) with Logging {
+
+ /**
+ * Ensure the RDD is fully cached so the partitions can be recovered later.
+ */
+ protected override def doCheckpoint(): CheckpointRDD[T] = {
+ val level = rdd.getStorageLevel
+
+ // Assume storage level uses disk; otherwise memory eviction may cause data loss
+ assume(level.useDisk, s"Storage level $level is not appropriate for local checkpointing")
+
+ // Not all actions compute all partitions of the RDD (e.g. take). For correctness, we
+ // must cache any missing partitions. TODO: avoid running another job here (SPARK-8582).
+ val action = (tc: TaskContext, iterator: Iterator[T]) => Utils.getIteratorSize(iterator)
+ val missingPartitionIndices = rdd.partitions.map(_.index).filter { i =>
+ !SparkEnv.get.blockManager.master.contains(RDDBlockId(rdd.id, i))
+ }
+ if (missingPartitionIndices.nonEmpty) {
+ rdd.sparkContext.runJob(rdd, action, missingPartitionIndices)
+ }
+
+ new LocalCheckpointRDD[T](rdd)
+ }
+
+}
+
+private[spark] object LocalRDDCheckpointData {
+
+ val DEFAULT_STORAGE_LEVEL = StorageLevel.MEMORY_AND_DISK
+
+ /**
+ * Transform the specified storage level to one that uses disk.
+ *
+ * This guarantees that the RDD can be recomputed multiple times correctly as long as
+ * executors do not fail. Otherwise, if the RDD is cached in memory only, for instance,
+ * the checkpoint data will be lost if the relevant block is evicted from memory.
+ *
+ * This method is idempotent.
+ */
+ def transformStorageLevel(level: StorageLevel): StorageLevel = {
+ // If this RDD is to be cached off-heap, fail fast since we cannot provide any
+ // correctness guarantees about subsequent computations after the first one
+ if (level.useOffHeap) {
+ throw new SparkException("Local checkpointing is not compatible with off-heap caching.")
+ }
+
+ StorageLevel(useDisk = true, level.useMemory, level.deserialized, level.replication)
+ }
+}
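The storage-level transformation above is private to Spark, but its effect is visible through the public API. A small sketch (assuming an existing SparkContext `sc`) of the behavior this patch introduces:

```scala
import org.apache.spark.storage.StorageLevel

val rdd = sc.parallelize(1 to 100).map(_ + 1)
rdd.localCheckpoint()                   // persists with the default MEMORY_AND_DISK level
rdd.persist(StorageLevel.MEMORY_ONLY)   // adapted to an equivalent disk-backed level
assert(rdd.getStorageLevel.useDisk)     // the adapted level always uses disk
// rdd.persist(StorageLevel.OFF_HEAP)   // would throw: off-heap caching is not supported here
```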
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 6d61d22738..081c721f23 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -149,23 +149,43 @@ abstract class RDD[T: ClassTag](
}
/**
- * Set this RDD's storage level to persist its values across operations after the first time
- * it is computed. This can only be used to assign a new storage level if the RDD does not
- * have a storage level set yet..
+ * Mark this RDD for persisting using the specified level.
+ *
+ * @param newLevel the target storage level
+ * @param allowOverride whether to override any existing level with the new one
*/
- def persist(newLevel: StorageLevel): this.type = {
+ private def persist(newLevel: StorageLevel, allowOverride: Boolean): this.type = {
// TODO: Handle changes of StorageLevel
- if (storageLevel != StorageLevel.NONE && newLevel != storageLevel) {
+ if (storageLevel != StorageLevel.NONE && newLevel != storageLevel && !allowOverride) {
throw new UnsupportedOperationException(
"Cannot change storage level of an RDD after it was already assigned a level")
}
- sc.persistRDD(this)
- // Register the RDD with the ContextCleaner for automatic GC-based cleanup
- sc.cleaner.foreach(_.registerRDDForCleanup(this))
+ // If this is the first time this RDD is marked for persisting, register it
+ // with the SparkContext for cleanups and accounting. Do this only once.
+ if (storageLevel == StorageLevel.NONE) {
+ sc.cleaner.foreach(_.registerRDDForCleanup(this))
+ sc.persistRDD(this)
+ }
storageLevel = newLevel
this
}
+ /**
+ * Set this RDD's storage level to persist its values across operations after the first time
+ * it is computed. This can only be used to assign a new storage level if the RDD does not
+ * have a storage level set yet. Local checkpointing is an exception.
+ */
+ def persist(newLevel: StorageLevel): this.type = {
+ if (isLocallyCheckpointed) {
+ // This means the user previously called localCheckpoint(), which should have already
+ // marked this RDD for persisting. Here we should override the old storage level with
+ // one that is explicitly requested by the user (after adapting it to use disk).
+ persist(LocalRDDCheckpointData.transformStorageLevel(newLevel), allowOverride = true)
+ } else {
+ persist(newLevel, allowOverride = false)
+ }
+ }
+
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def persist(): this.type = persist(StorageLevel.MEMORY_ONLY)
@@ -1448,33 +1468,99 @@ abstract class RDD[T: ClassTag](
/**
* Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint
- * directory set with SparkContext.setCheckpointDir() and all references to its parent
+ * directory set with `SparkContext#setCheckpointDir` and all references to its parent
* RDDs will be removed. This function must be called before any job has been
* executed on this RDD. It is strongly recommended that this RDD is persisted in
* memory, otherwise saving it on a file will require recomputation.
*/
- def checkpoint(): Unit = {
+ def checkpoint(): Unit = RDDCheckpointData.synchronized {
+ // NOTE: we use a global lock here due to complexities downstream with ensuring
+ // children RDD partitions point to the correct parent partitions. In the future
+ // we should revisit this consideration.
if (context.checkpointDir.isEmpty) {
throw new SparkException("Checkpoint directory has not been set in the SparkContext")
} else if (checkpointData.isEmpty) {
- // NOTE: we use a global lock here due to complexities downstream with ensuring
- // children RDD partitions point to the correct parent partitions. In the future
- // we should revisit this consideration.
- RDDCheckpointData.synchronized {
- checkpointData = Some(new RDDCheckpointData(this))
- }
+ checkpointData = Some(new ReliableRDDCheckpointData(this))
+ }
+ }
+
+ /**
+ * Mark this RDD for local checkpointing using Spark's existing caching layer.
+ *
+ * This method is for users who wish to truncate RDD lineages while skipping the expensive
+ * step of replicating the materialized data in a reliable distributed file system. This is
+ * useful for RDDs with long lineages that need to be truncated periodically (e.g. GraphX).
+ *
+ * Local checkpointing sacrifices fault-tolerance for performance. In particular, checkpointed
+ * data is written to ephemeral local storage in the executors instead of to a reliable,
+ * fault-tolerant storage. The effect is that if an executor fails during the computation,
+ * the checkpointed data may no longer be accessible, causing an irrecoverable job failure.
+ *
+ * This is NOT safe to use with dynamic allocation, which removes executors along
+ * with their cached blocks. If you must use both features, you are advised to set
+ * `spark.dynamicAllocation.cachedExecutorIdleTimeout` to a high value.
+ *
+ * The checkpoint directory set through `SparkContext#setCheckpointDir` is not used.
+ */
+ def localCheckpoint(): this.type = RDDCheckpointData.synchronized {
+ if (conf.getBoolean("spark.dynamicAllocation.enabled", false) &&
+ conf.contains("spark.dynamicAllocation.cachedExecutorIdleTimeout")) {
+ logWarning("Local checkpointing is NOT safe to use with dynamic allocation, " +
+ "which removes executors along with their cached blocks. If you must use both " +
+ "features, you are advised to set `spark.dynamicAllocation.cachedExecutorIdleTimeout` " +
+ "to a high value. E.g. If you plan to use the RDD for 1 hour, set the timeout to " +
+ "at least 1 hour.")
+ }
+
+ // Note: At this point we do not actually know whether the user will call persist() on
+ // this RDD later, so we must explicitly call it here ourselves to ensure the cached
+ // blocks are registered for cleanup later in the SparkContext.
+ //
+ // If, however, the user has already called persist() on this RDD, then we must adapt
+ // the storage level he/she specified to one that is appropriate for local checkpointing
+ // (i.e. uses disk) to guarantee correctness.
+
+ if (storageLevel == StorageLevel.NONE) {
+ persist(LocalRDDCheckpointData.DEFAULT_STORAGE_LEVEL)
+ } else {
+ persist(LocalRDDCheckpointData.transformStorageLevel(storageLevel), allowOverride = true)
}
+
+ checkpointData match {
+ case Some(reliable: ReliableRDDCheckpointData[_]) => logWarning(
+ "RDD was already marked for reliable checkpointing: overriding with local checkpoint.")
+ case _ =>
+ }
+ checkpointData = Some(new LocalRDDCheckpointData(this))
+ this
}
/**
- * Return whether this RDD has been checkpointed or not
+ * Return whether this RDD is marked for checkpointing, either reliably or locally.
*/
def isCheckpointed: Boolean = checkpointData.exists(_.isCheckpointed)
/**
- * Gets the name of the file to which this RDD was checkpointed
+ * Return whether this RDD is marked for local checkpointing.
+ * Exposed for testing.
*/
- def getCheckpointFile: Option[String] = checkpointData.flatMap(_.getCheckpointFile)
+ private[rdd] def isLocallyCheckpointed: Boolean = {
+ checkpointData match {
+ case Some(_: LocalRDDCheckpointData[T]) => true
+ case _ => false
+ }
+ }
+
+ /**
+ * Gets the name of the directory to which this RDD was checkpointed.
+ * This is not defined if the RDD is checkpointed locally.
+ */
+ def getCheckpointFile: Option[String] = {
+ checkpointData match {
+ case Some(reliable: ReliableRDDCheckpointData[T]) => reliable.getCheckpointDir
+ case _ => None
+ }
+ }
// =======================================================================
// Other internal methods and fields
@@ -1545,7 +1631,7 @@ abstract class RDD[T: ClassTag](
if (!doCheckpointCalled) {
doCheckpointCalled = true
if (checkpointData.isDefined) {
- checkpointData.get.doCheckpoint()
+ checkpointData.get.checkpoint()
} else {
dependencies.foreach(_.rdd.doCheckpoint())
}
@@ -1557,7 +1643,7 @@ abstract class RDD[T: ClassTag](
* Changes the dependencies of this RDD from its original parents to a new RDD (`newRDD`)
* created from the checkpoint file, and forget its old dependencies and partitions.
*/
- private[spark] def markCheckpointed(checkpointRDD: RDD[_]) {
+ private[spark] def markCheckpointed(): Unit = {
clearDependencies()
partitions_ = null
deps = null // Forget the constructor argument for dependencies too
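The dynamic-allocation warning in `localCheckpoint()` above points at a configuration knob. A hedged configuration sketch using Spark's standard keys (master/deploy settings and the other dynamic-allocation prerequisites, such as the external shuffle service, are omitted):

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .setAppName("local-checkpoint-with-dynamic-allocation")  // placeholder name
  .set("spark.dynamicAllocation.enabled", "true")
  // Keep executors holding cached blocks alive at least as long as the
  // locally checkpointed RDD will be used (here, one hour).
  .set("spark.dynamicAllocation.cachedExecutorIdleTimeout", "1h")
// conf would then be passed to SparkContext or spark-submit as usual.
```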
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
index 4f954363be..0e43520870 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -19,10 +19,7 @@ package org.apache.spark.rdd
import scala.reflect.ClassTag
-import org.apache.hadoop.fs.Path
-
-import org.apache.spark._
-import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.Partition
/**
* Enumeration to manage state transitions of an RDD through checkpointing
@@ -39,39 +36,31 @@ private[spark] object CheckpointState extends Enumeration {
* as well as, manages the post-checkpoint state by providing the updated partitions,
* iterator and preferred locations of the checkpointed RDD.
*/
-private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
- extends Logging with Serializable {
+private[spark] abstract class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
+ extends Serializable {
import CheckpointState._
// The checkpoint state of the associated RDD.
- private var cpState = Initialized
-
- // The file to which the associated RDD has been checkpointed to
- private var cpFile: Option[String] = None
+ protected var cpState = Initialized
- // The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD.
- // This is defined if and only if `cpState` is `Checkpointed`.
+ // The RDD that contains our checkpointed data
private var cpRDD: Option[CheckpointRDD[T]] = None
// TODO: are we sure we need to use a global lock in the following methods?
- // Is the RDD already checkpointed
+ /**
+ * Return whether the checkpoint data for this RDD is already persisted.
+ */
def isCheckpointed: Boolean = RDDCheckpointData.synchronized {
cpState == Checkpointed
}
- // Get the file to which this RDD was checkpointed to as an Option
- def getCheckpointFile: Option[String] = RDDCheckpointData.synchronized {
- cpFile
- }
-
/**
- * Materialize this RDD and write its content to a reliable DFS.
+ * Materialize this RDD and persist its content.
* This is called immediately after the first action invoked on this RDD has completed.
*/
- def doCheckpoint(): Unit = {
-
+ final def checkpoint(): Unit = {
// Guard against multiple threads checkpointing the same RDD by
// atomically flipping the state of this RDDCheckpointData
RDDCheckpointData.synchronized {
@@ -82,64 +71,41 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
}
}
- // Create the output path for the checkpoint
- val path = RDDCheckpointData.rddCheckpointDataPath(rdd.context, rdd.id).get
- val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
- if (!fs.mkdirs(path)) {
- throw new SparkException(s"Failed to create checkpoint path $path")
- }
-
- // Save to file, and reload it as an RDD
- val broadcastedConf = rdd.context.broadcast(
- new SerializableConfiguration(rdd.context.hadoopConfiguration))
- val newRDD = new CheckpointRDD[T](rdd.context, path.toString)
- if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) {
- rdd.context.cleaner.foreach { cleaner =>
- cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id)
- }
- }
-
- // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
- rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _)
- if (newRDD.partitions.length != rdd.partitions.length) {
- throw new SparkException(
- "Checkpoint RDD " + newRDD + "(" + newRDD.partitions.length + ") has different " +
- "number of partitions than original RDD " + rdd + "(" + rdd.partitions.length + ")")
- }
+ val newRDD = doCheckpoint()
- // Change the dependencies and partitions of the RDD
+ // Update our state and truncate the RDD lineage
RDDCheckpointData.synchronized {
- cpFile = Some(path.toString)
cpRDD = Some(newRDD)
- rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions
cpState = Checkpointed
+ rdd.markCheckpointed()
}
- logInfo(s"Done checkpointing RDD ${rdd.id} to $path, new parent is RDD ${newRDD.id}")
- }
-
- def getPartitions: Array[Partition] = RDDCheckpointData.synchronized {
- cpRDD.get.partitions
}
- def checkpointRDD: Option[CheckpointRDD[T]] = RDDCheckpointData.synchronized {
- cpRDD
- }
-}
+ /**
+ * Materialize this RDD and persist its content.
+ *
+ * Subclasses should override this method to define custom checkpointing behavior.
+ * @return the checkpoint RDD created in the process.
+ */
+ protected def doCheckpoint(): CheckpointRDD[T]
-private[spark] object RDDCheckpointData {
+ /**
+ * Return the RDD that contains our checkpointed data.
+ * This is only defined if the checkpoint state is `Checkpointed`.
+ */
+ def checkpointRDD: Option[CheckpointRDD[T]] = RDDCheckpointData.synchronized { cpRDD }
- /** Return the path of the directory to which this RDD's checkpoint data is written. */
- def rddCheckpointDataPath(sc: SparkContext, rddId: Int): Option[Path] = {
- sc.checkpointDir.map { dir => new Path(dir, s"rdd-$rddId") }
+ /**
+ * Return the partitions of the resulting checkpoint RDD.
+ * For tests only.
+ */
+ def getPartitions: Array[Partition] = RDDCheckpointData.synchronized {
+ cpRDD.map(_.partitions).getOrElse { Array.empty }
}
- /** Clean up the files associated with the checkpoint data for this RDD. */
- def clearRDDCheckpointData(sc: SparkContext, rddId: Int): Unit = {
- rddCheckpointDataPath(sc, rddId).foreach { path =>
- val fs = path.getFileSystem(sc.hadoopConfiguration)
- if (fs.exists(path)) {
- fs.delete(path, true)
- }
- }
- }
}
+
+/**
+ * Global lock for synchronizing checkpoint operations.
+ */
+private[spark] object RDDCheckpointData
diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
new file mode 100644
index 0000000000..35d8b0bfd1
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import java.io.IOException
+
+import scala.reflect.ClassTag
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark._
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.util.{SerializableConfiguration, Utils}
+
+/**
+ * An RDD that reads from checkpoint files previously written to reliable storage.
+ */
+private[spark] class ReliableCheckpointRDD[T: ClassTag](
+ @transient sc: SparkContext,
+ val checkpointPath: String)
+ extends CheckpointRDD[T](sc) {
+
+ @transient private val hadoopConf = sc.hadoopConfiguration
+ @transient private val cpath = new Path(checkpointPath)
+ @transient private val fs = cpath.getFileSystem(hadoopConf)
+ private val broadcastedConf = sc.broadcast(new SerializableConfiguration(hadoopConf))
+
+ // Fail fast if checkpoint directory does not exist
+ require(fs.exists(cpath), s"Checkpoint directory does not exist: $checkpointPath")
+
+ /**
+ * Return the path of the checkpoint directory this RDD reads data from.
+ */
+ override def getCheckpointFile: Option[String] = Some(checkpointPath)
+
+ /**
+ * Return partitions described by the files in the checkpoint directory.
+ *
+ * Since the original RDD may belong to a prior application, there is no way to know a
+ * priori the number of partitions to expect. This method assumes that the original set of
+ * checkpoint files are fully preserved in a reliable storage across application lifespans.
+ */
+ protected override def getPartitions: Array[Partition] = {
+ // listStatus can throw exception if path does not exist.
+ val inputFiles = fs.listStatus(cpath)
+ .map(_.getPath)
+ .filter(_.getName.startsWith("part-"))
+ .sortBy(_.toString)
+ // Fail fast if input files are invalid
+ inputFiles.zipWithIndex.foreach { case (path, i) =>
+ if (!path.toString.endsWith(ReliableCheckpointRDD.checkpointFileName(i))) {
+ throw new SparkException(s"Invalid checkpoint file: $path")
+ }
+ }
+ Array.tabulate(inputFiles.length)(i => new CheckpointRDDPartition(i))
+ }
+
+ /**
+ * Return the locations of the checkpoint file associated with the given partition.
+ */
+ protected override def getPreferredLocations(split: Partition): Seq[String] = {
+ val status = fs.getFileStatus(
+ new Path(checkpointPath, ReliableCheckpointRDD.checkpointFileName(split.index)))
+ val locations = fs.getFileBlockLocations(status, 0, status.getLen)
+ locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
+ }
+
+ /**
+ * Read the content of the checkpoint file associated with the given partition.
+ */
+ override def compute(split: Partition, context: TaskContext): Iterator[T] = {
+ val file = new Path(checkpointPath, ReliableCheckpointRDD.checkpointFileName(split.index))
+ ReliableCheckpointRDD.readCheckpointFile(file, broadcastedConf, context)
+ }
+
+}
+
+private[spark] object ReliableCheckpointRDD extends Logging {
+
+ /**
+ * Return the checkpoint file name for the given partition.
+ */
+ private def checkpointFileName(partitionIndex: Int): String = {
+ "part-%05d".format(partitionIndex)
+ }
+
+ /**
+ * Write this partition's values to a checkpoint file.
+ */
+ def writeCheckpointFile[T: ClassTag](
+ path: String,
+ broadcastedConf: Broadcast[SerializableConfiguration],
+ blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
+ val env = SparkEnv.get
+ val outputDir = new Path(path)
+ val fs = outputDir.getFileSystem(broadcastedConf.value.value)
+
+ val finalOutputName = ReliableCheckpointRDD.checkpointFileName(ctx.partitionId())
+ val finalOutputPath = new Path(outputDir, finalOutputName)
+ val tempOutputPath =
+ new Path(outputDir, s".$finalOutputName-attempt-${ctx.attemptNumber()}")
+
+ if (fs.exists(tempOutputPath)) {
+ throw new IOException(s"Checkpoint failed: temporary path $tempOutputPath already exists")
+ }
+ val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
+
+ val fileOutputStream = if (blockSize < 0) {
+ fs.create(tempOutputPath, false, bufferSize)
+ } else {
+ // This is mainly for testing purpose
+ fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize)
+ }
+ val serializer = env.serializer.newInstance()
+ val serializeStream = serializer.serializeStream(fileOutputStream)
+ Utils.tryWithSafeFinally {
+ serializeStream.writeAll(iterator)
+ } {
+ serializeStream.close()
+ }
+
+ if (!fs.rename(tempOutputPath, finalOutputPath)) {
+ if (!fs.exists(finalOutputPath)) {
+ logInfo(s"Deleting tempOutputPath $tempOutputPath")
+ fs.delete(tempOutputPath, false)
+ throw new IOException("Checkpoint failed: failed to save output of task: " +
+ s"${ctx.attemptNumber()} and final output path does not exist: $finalOutputPath")
+ } else {
+ // Some other copy of this task must've finished before us and renamed it
+ logInfo(s"Final output path $finalOutputPath already exists; not overwriting it")
+ fs.delete(tempOutputPath, false)
+ }
+ }
+ }
+
+ /**
+ * Read the content of the specified checkpoint file.
+ */
+ def readCheckpointFile[T](
+ path: Path,
+ broadcastedConf: Broadcast[SerializableConfiguration],
+ context: TaskContext): Iterator[T] = {
+ val env = SparkEnv.get
+ val fs = path.getFileSystem(broadcastedConf.value.value)
+ val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
+ val fileInputStream = fs.open(path, bufferSize)
+ val serializer = env.serializer.newInstance()
+ val deserializeStream = serializer.deserializeStream(fileInputStream)
+
+ // Register an on-task-completion callback to close the input stream.
+ context.addTaskCompletionListener(context => deserializeStream.close())
+
+ deserializeStream.asIterator.asInstanceOf[Iterator[T]]
+ }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala
new file mode 100644
index 0000000000..1df8eef5ff
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.reflect.ClassTag
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark._
+import org.apache.spark.util.SerializableConfiguration
+
+/**
+ * An implementation of checkpointing that writes the RDD data to reliable storage.
+ * This allows drivers to be restarted on failure with previously computed state.
+ */
+private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
+ extends RDDCheckpointData[T](rdd) with Logging {
+
+ // The directory to which the associated RDD has been checkpointed to
+ // This is assumed to be a non-local path that points to some reliable storage
+ private val cpDir: String =
+ ReliableRDDCheckpointData.checkpointPath(rdd.context, rdd.id)
+ .map(_.toString)
+ .getOrElse { throw new SparkException("Checkpoint dir must be specified.") }
+
+ /**
+ * Return the directory to which this RDD was checkpointed.
+ * If the RDD is not checkpointed yet, return None.
+ */
+ def getCheckpointDir: Option[String] = RDDCheckpointData.synchronized {
+ if (isCheckpointed) {
+ Some(cpDir.toString)
+ } else {
+ None
+ }
+ }
+
+ /**
+ * Materialize this RDD and write its content to a reliable DFS.
+ * This is called immediately after the first action invoked on this RDD has completed.
+ */
+ protected override def doCheckpoint(): CheckpointRDD[T] = {
+
+ // Create the output path for the checkpoint
+ val path = new Path(cpDir)
+ val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
+ if (!fs.mkdirs(path)) {
+ throw new SparkException(s"Failed to create checkpoint path $cpDir")
+ }
+
+ // Save to file, and reload it as an RDD
+ val broadcastedConf = rdd.context.broadcast(
+ new SerializableConfiguration(rdd.context.hadoopConfiguration))
+ // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
+ rdd.context.runJob(rdd, ReliableCheckpointRDD.writeCheckpointFile[T](cpDir, broadcastedConf) _)
+ val newRDD = new ReliableCheckpointRDD[T](rdd.context, cpDir)
+ if (newRDD.partitions.length != rdd.partitions.length) {
+ throw new SparkException(
+ s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " +
+ s"number of partitions from original RDD $rdd(${rdd.partitions.length})")
+ }
+
+ // Optionally clean our checkpoint files if the reference is out of scope
+ if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) {
+ rdd.context.cleaner.foreach { cleaner =>
+ cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id)
+ }
+ }
+
+ logInfo(s"Done checkpointing RDD ${rdd.id} to $cpDir, new parent is RDD ${newRDD.id}")
+
+ newRDD
+ }
+
+}
+
+private[spark] object ReliableRDDCheckpointData {
+
+ /** Return the path of the directory to which this RDD's checkpoint data is written. */
+ def checkpointPath(sc: SparkContext, rddId: Int): Option[Path] = {
+ sc.checkpointDir.map { dir => new Path(dir, s"rdd-$rddId") }
+ }
+
+ /** Clean up the files associated with the checkpoint data for this RDD. */
+ def cleanCheckpoint(sc: SparkContext, rddId: Int): Unit = {
+ checkpointPath(sc, rddId).foreach { path =>
+ val fs = path.getFileSystem(sc.hadoopConfiguration)
+ if (fs.exists(path)) {
+ fs.delete(path, true)
+ }
+ }
+ }
+}
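For completeness, a small sketch (placeholder names and paths; assumes an existing cluster-appropriate setup) of the optional checkpoint-file cleanup wired up above:

```scala
import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf()
  .setAppName("checkpoint-cleanup-demo")       // placeholder name
  .setMaster("local[2]")                       // placeholder master
  .set("spark.cleaner.referenceTracking.cleanCheckpoints", "true")
val sc = new SparkContext(conf)
sc.setCheckpointDir("/tmp/spark-checkpoints")  // placeholder; typically a DFS path

val rdd = sc.parallelize(1 to 100).map(_ * 2)
rdd.checkpoint()
rdd.count()  // writes <checkpointDir>/rdd-<id>/part-00000 ... via ReliableCheckpointRDD
// Once the checkpointed RDD is no longer referenced and gets garbage collected, the
// ContextCleaner invokes ReliableRDDCheckpointData.cleanCheckpoint to delete rdd-<id>.
```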
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index cc50e6d79a..d343bb95cb 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -25,11 +25,15 @@ import org.apache.spark.rdd._
import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}
import org.apache.spark.util.Utils
+/**
+ * Test suite for end-to-end checkpointing functionality.
+ * This tests both reliable checkpoints and local checkpoints.
+ */
class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging {
- var checkpointDir: File = _
- val partitioner = new HashPartitioner(2)
+ private var checkpointDir: File = _
+ private val partitioner = new HashPartitioner(2)
- override def beforeEach() {
+ override def beforeEach(): Unit = {
super.beforeEach()
checkpointDir = File.createTempFile("temp", "", Utils.createTempDir())
checkpointDir.delete()
@@ -37,40 +41,43 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
sc.setCheckpointDir(checkpointDir.toString)
}
- override def afterEach() {
+ override def afterEach(): Unit = {
super.afterEach()
Utils.deleteRecursively(checkpointDir)
}
- test("basic checkpointing") {
+ runTest("basic checkpointing") { reliableCheckpoint: Boolean =>
val parCollection = sc.makeRDD(1 to 4)
val flatMappedRDD = parCollection.flatMap(x => 1 to x)
- flatMappedRDD.checkpoint()
+ checkpoint(flatMappedRDD, reliableCheckpoint)
assert(flatMappedRDD.dependencies.head.rdd === parCollection)
val result = flatMappedRDD.collect()
assert(flatMappedRDD.dependencies.head.rdd != parCollection)
assert(flatMappedRDD.collect() === result)
}
- test("RDDs with one-to-one dependencies") {
- testRDD(_.map(x => x.toString))
- testRDD(_.flatMap(x => 1 to x))
- testRDD(_.filter(_ % 2 == 0))
- testRDD(_.sample(false, 0.5, 0))
- testRDD(_.glom())
- testRDD(_.mapPartitions(_.map(_.toString)))
- testRDD(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
- testRDD(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
- testRDD(_.pipe(Seq("cat")))
+ runTest("RDDs with one-to-one dependencies") { reliableCheckpoint: Boolean =>
+ testRDD(_.map(x => x.toString), reliableCheckpoint)
+ testRDD(_.flatMap(x => 1 to x), reliableCheckpoint)
+ testRDD(_.filter(_ % 2 == 0), reliableCheckpoint)
+ testRDD(_.sample(false, 0.5, 0), reliableCheckpoint)
+ testRDD(_.glom(), reliableCheckpoint)
+ testRDD(_.mapPartitions(_.map(_.toString)), reliableCheckpoint)
+ testRDD(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString), reliableCheckpoint)
+ testRDD(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x),
+ reliableCheckpoint)
+ testRDD(_.pipe(Seq("cat")), reliableCheckpoint)
}
- test("ParallelCollection") {
+ runTest("ParallelCollectionRDD") { reliableCheckpoint: Boolean =>
val parCollection = sc.makeRDD(1 to 4, 2)
val numPartitions = parCollection.partitions.size
- parCollection.checkpoint()
+ checkpoint(parCollection, reliableCheckpoint)
assert(parCollection.dependencies === Nil)
val result = parCollection.collect()
- assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result)
+ if (reliableCheckpoint) {
+ assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result)
+ }
assert(parCollection.dependencies != Nil)
assert(parCollection.partitions.length === numPartitions)
assert(parCollection.partitions.toList ===
@@ -78,44 +85,46 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
assert(parCollection.collect() === result)
}
- test("BlockRDD") {
+ runTest("BlockRDD") { reliableCheckpoint: Boolean =>
val blockId = TestBlockId("id")
val blockManager = SparkEnv.get.blockManager
blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY)
val blockRDD = new BlockRDD[String](sc, Array(blockId))
val numPartitions = blockRDD.partitions.size
- blockRDD.checkpoint()
+ checkpoint(blockRDD, reliableCheckpoint)
val result = blockRDD.collect()
- assert(sc.checkpointFile[String](blockRDD.getCheckpointFile.get).collect() === result)
+ if (reliableCheckpoint) {
+ assert(sc.checkpointFile[String](blockRDD.getCheckpointFile.get).collect() === result)
+ }
assert(blockRDD.dependencies != Nil)
assert(blockRDD.partitions.length === numPartitions)
assert(blockRDD.partitions.toList === blockRDD.checkpointData.get.getPartitions.toList)
assert(blockRDD.collect() === result)
}
- test("ShuffledRDD") {
+ runTest("ShuffleRDD") { reliableCheckpoint: Boolean =>
testRDD(rdd => {
// Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD
new ShuffledRDD[Int, Int, Int](rdd.map(x => (x % 2, 1)), partitioner)
- })
+ }, reliableCheckpoint)
}
- test("UnionRDD") {
+ runTest("UnionRDD") { reliableCheckpoint: Boolean =>
def otherRDD: RDD[Int] = sc.makeRDD(1 to 10, 1)
- testRDD(_.union(otherRDD))
- testRDDPartitions(_.union(otherRDD))
+ testRDD(_.union(otherRDD), reliableCheckpoint)
+ testRDDPartitions(_.union(otherRDD), reliableCheckpoint)
}
- test("CartesianRDD") {
+ runTest("CartesianRDD") { reliableCheckpoint: Boolean =>
def otherRDD: RDD[Int] = sc.makeRDD(1 to 10, 1)
- testRDD(new CartesianRDD(sc, _, otherRDD))
- testRDDPartitions(new CartesianRDD(sc, _, otherRDD))
+ testRDD(new CartesianRDD(sc, _, otherRDD), reliableCheckpoint)
+ testRDDPartitions(new CartesianRDD(sc, _, otherRDD), reliableCheckpoint)
// Test that the CartesianRDD updates parent partitions (CartesianRDD.s1/s2) after
// the parent RDD has been checkpointed and parent partitions have been changed.
// Note that this test is very specific to the current implementation of CartesianRDD.
val ones = sc.makeRDD(1 to 100, 10).map(x => x)
- ones.checkpoint() // checkpoint that MappedRDD
+ checkpoint(ones, reliableCheckpoint) // checkpoint that MappedRDD
val cartesian = new CartesianRDD(sc, ones, ones)
val splitBeforeCheckpoint =
serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition])
@@ -129,16 +138,16 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
)
}
- test("CoalescedRDD") {
- testRDD(_.coalesce(2))
- testRDDPartitions(_.coalesce(2))
+ runTest("CoalescedRDD") { reliableCheckpoint: Boolean =>
+ testRDD(_.coalesce(2), reliableCheckpoint)
+ testRDDPartitions(_.coalesce(2), reliableCheckpoint)
// Test that the CoalescedRDDPartition updates parent partitions (CoalescedRDDPartition.parents)
// after the parent RDD has been checkpointed and parent partitions have been changed.
// Note that this test is very specific to the current implementation of
// CoalescedRDDPartitions.
val ones = sc.makeRDD(1 to 100, 10).map(x => x)
- ones.checkpoint() // checkpoint that MappedRDD
+ checkpoint(ones, reliableCheckpoint) // checkpoint that MappedRDD
val coalesced = new CoalescedRDD(ones, 2)
val splitBeforeCheckpoint =
serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition])
@@ -151,7 +160,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
)
}
- test("CoGroupedRDD") {
+ runTest("CoGroupedRDD") { reliableCheckpoint: Boolean =>
val longLineageRDD1 = generateFatPairRDD()
// Collect the RDD as sequences instead of arrays to enable equality tests in testRDD
@@ -160,26 +169,26 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
testRDD(rdd => {
CheckpointSuite.cogroup(longLineageRDD1, rdd.map(x => (x % 2, 1)), partitioner)
- }, seqCollectFunc)
+ }, reliableCheckpoint, seqCollectFunc)
val longLineageRDD2 = generateFatPairRDD()
testRDDPartitions(rdd => {
CheckpointSuite.cogroup(
longLineageRDD2, sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)), partitioner)
- }, seqCollectFunc)
+ }, reliableCheckpoint, seqCollectFunc)
}
- test("ZippedPartitionsRDD") {
- testRDD(rdd => rdd.zip(rdd.map(x => x)))
- testRDDPartitions(rdd => rdd.zip(rdd.map(x => x)))
+ runTest("ZippedPartitionsRDD") { reliableCheckpoint: Boolean =>
+ testRDD(rdd => rdd.zip(rdd.map(x => x)), reliableCheckpoint)
+ testRDDPartitions(rdd => rdd.zip(rdd.map(x => x)), reliableCheckpoint)
// Test that ZippedPartitionsRDD updates parent partitions after parent RDDs have
// been checkpointed and parent partitions have been changed.
// Note that this test is very specific to the implementation of ZippedPartitionsRDD.
val rdd = generateFatRDD()
val zippedRDD = rdd.zip(rdd.map(x => x)).asInstanceOf[ZippedPartitionsRDD2[_, _, _]]
- zippedRDD.rdd1.checkpoint()
- zippedRDD.rdd2.checkpoint()
+ checkpoint(zippedRDD.rdd1, reliableCheckpoint)
+ checkpoint(zippedRDD.rdd2, reliableCheckpoint)
val partitionBeforeCheckpoint =
serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartitionsPartition])
zippedRDD.count()
@@ -194,27 +203,27 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
)
}
- test("PartitionerAwareUnionRDD") {
+ runTest("PartitionerAwareUnionRDD") { reliableCheckpoint: Boolean =>
testRDD(rdd => {
new PartitionerAwareUnionRDD[(Int, Int)](sc, Array(
generateFatPairRDD(),
rdd.map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _)
))
- })
+ }, reliableCheckpoint)
testRDDPartitions(rdd => {
new PartitionerAwareUnionRDD[(Int, Int)](sc, Array(
generateFatPairRDD(),
rdd.map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _)
))
- })
+ }, reliableCheckpoint)
// Test that the PartitionerAwareUnionRDD updates parent partitions
// (PartitionerAwareUnionRDD.parents) after the parent RDD has been checkpointed and parent
// partitions have been changed. Note that this test is very specific to the current
// implementation of PartitionerAwareUnionRDD.
val pairRDD = generateFatPairRDD()
- pairRDD.checkpoint()
+ checkpoint(pairRDD, reliableCheckpoint)
val unionRDD = new PartitionerAwareUnionRDD(sc, Array(pairRDD))
val partitionBeforeCheckpoint = serializeDeserialize(
unionRDD.partitions.head.asInstanceOf[PartitionerAwareUnionRDDPartition])
@@ -228,17 +237,34 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
)
}
- test("CheckpointRDD with zero partitions") {
+ runTest("CheckpointRDD with zero partitions") { reliableCheckpoint: Boolean =>
val rdd = new BlockRDD[Int](sc, Array[BlockId]())
assert(rdd.partitions.size === 0)
assert(rdd.isCheckpointed === false)
- rdd.checkpoint()
+ checkpoint(rdd, reliableCheckpoint)
assert(rdd.count() === 0)
assert(rdd.isCheckpointed === true)
assert(rdd.partitions.size === 0)
}
- def defaultCollectFunc[T](rdd: RDD[T]): Any = rdd.collect()
+ // Utility test methods
+
+ /** Checkpoint the RDD either locally or reliably. */
+ private def checkpoint(rdd: RDD[_], reliableCheckpoint: Boolean): Unit = {
+ if (reliableCheckpoint) {
+ rdd.checkpoint()
+ } else {
+ rdd.localCheckpoint()
+ }
+ }
+
+ /** Run a test twice, once for local checkpointing and once for reliable checkpointing. */
+ private def runTest(name: String)(body: Boolean => Unit): Unit = {
+ test(name + " [reliable checkpoint]")(body(true))
+ test(name + " [local checkpoint]")(body(false))
+ }
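As a sketch of how these helpers compose (the test name and operation below are illustrative, not taken from this patch), a single runTest body registers two ScalaTest cases, one per checkpoint mode:

    runTest("illustrative one-to-one op") { reliableCheckpoint: Boolean =>
      // testRDD checkpoints the operated RDD itself under the selected mode
      testRDD(rdd => rdd.map(_ + 1), reliableCheckpoint)
    }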
+
+ private def defaultCollectFunc[T](rdd: RDD[T]): Any = rdd.collect()
/**
* Test checkpointing of the RDD generated by the given operation. It tests whether the
@@ -246,11 +272,14 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
* on all RDDs that have a parent RDD (i.e., do not call on ParallelCollection, BlockRDD, etc.).
*
* @param op an operation to run on the RDD
+ * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints
* @param collectFunc a function for collecting the values in the RDD, in case there are
* non-comparable types like arrays that we want to convert to something that supports ==
*/
- def testRDD[U: ClassTag](op: (RDD[Int]) => RDD[U],
- collectFunc: RDD[U] => Any = defaultCollectFunc[U] _) {
+ private def testRDD[U: ClassTag](
+ op: (RDD[Int]) => RDD[U],
+ reliableCheckpoint: Boolean,
+ collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = {
// Generate the final RDD using given RDD operation
val baseRDD = generateFatRDD()
val operatedRDD = op(baseRDD)
@@ -267,14 +296,16 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
// Find serialized sizes before and after the checkpoint
logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
- operatedRDD.checkpoint()
+ checkpoint(operatedRDD, reliableCheckpoint)
val result = collectFunc(operatedRDD)
operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables
val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
// Test whether the checkpoint file has been created
- assert(collectFunc(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result)
+ if (reliableCheckpoint) {
+ assert(collectFunc(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result)
+ }
// Test whether dependencies have been changed from its earlier parent RDD
assert(operatedRDD.dependencies.head.rdd != parentRDD)
@@ -310,11 +341,14 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
* partitions (i.e., do not call it on simple RDD like MappedRDD).
*
* @param op an operation to run on the RDD
+ * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints
* @param collectFunc a function for collecting the values in the RDD, in case there are
* non-comparable types like arrays that we want to convert to something that supports ==
*/
- def testRDDPartitions[U: ClassTag](op: (RDD[Int]) => RDD[U],
- collectFunc: RDD[U] => Any = defaultCollectFunc[U] _) {
+ private def testRDDPartitions[U: ClassTag](
+ op: (RDD[Int]) => RDD[U],
+ reliableCheckpoint: Boolean,
+ collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = {
// Generate the final RDD using given RDD operation
val baseRDD = generateFatRDD()
val operatedRDD = op(baseRDD)
@@ -328,7 +362,10 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
// Find serialized sizes before and after the checkpoint
logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
- parentRDDs.foreach(_.checkpoint()) // checkpoint the parent RDD, not the generated one
+ // checkpoint the parent RDD, not the generated one
+ parentRDDs.foreach { rdd =>
+ checkpoint(rdd, reliableCheckpoint)
+ }
val result = collectFunc(operatedRDD) // force checkpointing
operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables
val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
@@ -350,7 +387,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
/**
* Generate an RDD such that both the RDD and its partitions have large size.
*/
- def generateFatRDD(): RDD[Int] = {
+ private def generateFatRDD(): RDD[Int] = {
new FatRDD(sc.makeRDD(1 to 100, 4)).map(x => x)
}
@@ -358,7 +395,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
* Generate a pair RDD (with partitioner) such that both the RDD and its partitions
* have large size.
*/
- def generateFatPairRDD(): RDD[(Int, Int)] = {
+ private def generateFatPairRDD(): RDD[(Int, Int)] = {
new FatPairRDD(sc.makeRDD(1 to 100, 4), partitioner).mapValues(x => x)
}
@@ -366,7 +403,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
* Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks
* upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint.
*/
- def getSerializedSizes(rdd: RDD[_]): (Int, Int) = {
+ private def getSerializedSizes(rdd: RDD[_]): (Int, Int) = {
val rddSize = Utils.serialize(rdd).size
val rddCpDataSize = Utils.serialize(rdd.checkpointData).size
val rddPartitionSize = Utils.serialize(rdd.partitions).size
@@ -394,7 +431,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
* contents after deserialization (e.g., the contents of an RDD split after
* it is sent to a slave along with a task)
*/
- def serializeDeserialize[T](obj: T): T = {
+ private def serializeDeserialize[T](obj: T): T = {
val bytes = Utils.serialize(obj)
Utils.deserialize[T](bytes)
}
@@ -402,10 +439,11 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
/**
* Recursively force the initialization of all the members of an RDD and its parents.
*/
- def initializeRdd(rdd: RDD[_]) {
+ private def initializeRdd(rdd: RDD[_]): Unit = {
rdd.partitions // forces evaluation of the partitions
- rdd.dependencies.map(_.rdd).foreach(initializeRdd(_))
+ rdd.dependencies.map(_.rdd).foreach(initializeRdd)
}
+
}
/** RDD partition that has large serialized size. */
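For orientation, a minimal user-level sketch of the two checkpoint modes this suite now exercises (master, app name, element counts, and the checkpoint directory are placeholders):

    import org.apache.spark.{SparkConf, SparkContext}

    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("checkpoint-sketch"))
    sc.setCheckpointDir("/tmp/spark-checkpoints") // reliable checkpoints need a directory, ideally on HDFS

    // Reliable checkpoint: written to the checkpoint directory, survives executor loss.
    val reliable = sc.parallelize(1 to 1000, 4).map(_ * 2)
    reliable.checkpoint()
    reliable.count() // the first action materializes the checkpoint and truncates the lineage

    // Local checkpoint: stored through the block manager on the executors; faster,
    // but lost if an executor holding the blocks fails.
    val local = sc.parallelize(1 to 1000, 4).map(_ * 2)
    local.localCheckpoint()
    local.count()

    sc.stop()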
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
index 26858ef277..0c14bef7be 100644
--- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -24,12 +24,11 @@ import scala.language.existentials
import scala.util.Random
import org.scalatest.BeforeAndAfter
-import org.scalatest.concurrent.{PatienceConfiguration, Eventually}
+import org.scalatest.concurrent.PatienceConfiguration
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._
-import org.apache.spark.SparkContext._
-import org.apache.spark.rdd.{RDDCheckpointData, RDD}
+import org.apache.spark.rdd.{ReliableRDDCheckpointData, RDD}
import org.apache.spark.storage._
import org.apache.spark.shuffle.hash.HashShuffleManager
import org.apache.spark.shuffle.sort.SortShuffleManager
@@ -52,6 +51,7 @@ abstract class ContextCleanerSuiteBase(val shuffleManager: Class[_] = classOf[Ha
.setAppName("ContextCleanerSuite")
.set("spark.cleaner.referenceTracking.blocking", "true")
.set("spark.cleaner.referenceTracking.blocking.shuffle", "true")
+ .set("spark.cleaner.referenceTracking.cleanCheckpoints", "true")
.set("spark.shuffle.manager", shuffleManager.getName)
before {
@@ -209,11 +209,11 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase {
postGCTester.assertCleanup()
}
- test("automatically cleanup checkpoint") {
+ test("automatically cleanup normal checkpoint") {
val checkpointDir = java.io.File.createTempFile("temp", "")
checkpointDir.deleteOnExit()
checkpointDir.delete()
- var rdd = newPairRDD
+ var rdd = newPairRDD()
sc.setCheckpointDir(checkpointDir.toString)
rdd.checkpoint()
rdd.cache()
@@ -221,23 +221,26 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase {
var rddId = rdd.id
// Confirm the checkpoint directory exists
- assert(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).isDefined)
- val path = RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get
+ assert(ReliableRDDCheckpointData.checkpointPath(sc, rddId).isDefined)
+ val path = ReliableRDDCheckpointData.checkpointPath(sc, rddId).get
val fs = path.getFileSystem(sc.hadoopConfiguration)
assert(fs.exists(path))
// the checkpoint should be cleaned up after GC, since cleanCheckpoints is enabled in the base conf
- var postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil, Nil)
+ var postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil, Seq(rddId))
rdd = null // Make RDD out of scope, ok if collected earlier
runGC()
postGCTester.assertCleanup()
- assert(fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get))
+ assert(!fs.exists(ReliableRDDCheckpointData.checkpointPath(sc, rddId).get))
+ // Verify that checkpoints are NOT cleaned up if the config is not enabled
sc.stop()
- val conf = new SparkConf().setMaster("local[2]").setAppName("cleanupCheckpoint").
- set("spark.cleaner.referenceTracking.cleanCheckpoints", "true")
+ val conf = new SparkConf()
+ .setMaster("local[2]")
+ .setAppName("cleanupCheckpoint")
+ .set("spark.cleaner.referenceTracking.cleanCheckpoints", "false")
sc = new SparkContext(conf)
- rdd = newPairRDD
+ rdd = newPairRDD()
sc.setCheckpointDir(checkpointDir.toString)
rdd.checkpoint()
rdd.cache()
@@ -245,17 +248,40 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase {
rddId = rdd.id
// Confirm the checkpoint directory exists
- assert(fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get))
+ assert(fs.exists(ReliableRDDCheckpointData.checkpointPath(sc, rddId).get))
// Reference rdd to defeat any early collection by the JVM
rdd.count()
// Test that GC does not cause checkpoint data cleanup when the config is disabled
- postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil, Seq(rddId))
+ postGCTester = new CleanerTester(sc, Seq(rddId))
rdd = null // Make RDD out of scope
runGC()
postGCTester.assertCleanup()
- assert(!fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get))
+ assert(fs.exists(ReliableRDDCheckpointData.checkpointPath(sc, rddId).get))
+ }
+
+ test("automatically clean up local checkpoint") {
+ // Note that this test is similar to the RDD cleanup
+ // test because the same underlying mechanism is used!
+ var rdd = newPairRDD().localCheckpoint()
+ assert(rdd.checkpointData.isDefined)
+ assert(rdd.checkpointData.get.checkpointRDD.isEmpty)
+ rdd.count()
+ assert(rdd.checkpointData.get.checkpointRDD.isDefined)
+
+ // Test that GC does not cause checkpoint cleanup due to a strong reference
+ val preGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id))
+ runGC()
+ intercept[Exception] {
+ preGCTester.assertCleanup()(timeout(1000 millis))
+ }
+
+ // Test that RDD going out of scope does cause the checkpoint blocks to be cleaned up
+ val postGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id))
+ rdd = null
+ runGC()
+ postGCTester.assertCleanup()
}
test("automatically cleanup RDD + shuffle + broadcast") {
@@ -408,7 +434,10 @@ class SortShuffleContextCleanerSuite extends ContextCleanerSuiteBase(classOf[Sor
}
-/** Class to test whether RDDs, shuffles, etc. have been successfully cleaned. */
+/**
+ * Class to test whether RDDs, shuffles, etc. have been successfully cleaned.
+ * The checkpoints here refer only to normal (reliable) checkpoints, not local checkpoints.
+ */
class CleanerTester(
sc: SparkContext,
rddIds: Seq[Int] = Seq.empty,
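For reference, a sketch of the configuration this suite depends on (master, app name, and directory are placeholders): reliable checkpoint files are only deleted when their RDD is garbage collected and the cleaner flag, which is off by default, is enabled.

    import org.apache.spark.{SparkConf, SparkContext}

    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("checkpoint-cleanup-sketch")
      .set("spark.cleaner.referenceTracking.cleanCheckpoints", "true") // defaults to false
    val sc = new SparkContext(conf)
    sc.setCheckpointDir("/tmp/spark-checkpoints")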
diff --git a/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala
new file mode 100644
index 0000000000..5103eb74b2
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala
@@ -0,0 +1,330 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark.{SparkException, SparkContext, LocalSparkContext, SparkFunSuite}
+
+import org.mockito.Mockito.spy
+import org.apache.spark.storage.{RDDBlockId, StorageLevel}
+
+/**
+ * Fine-grained tests for local checkpointing.
+ * For end-to-end tests, see CheckpointSuite.
+ */
+class LocalCheckpointSuite extends SparkFunSuite with LocalSparkContext {
+
+ override def beforeEach(): Unit = {
+ sc = new SparkContext("local[2]", "test")
+ }
+
+ test("transform storage level") {
+ val transform = LocalRDDCheckpointData.transformStorageLevel _
+ assert(transform(StorageLevel.NONE) === StorageLevel.DISK_ONLY)
+ assert(transform(StorageLevel.MEMORY_ONLY) === StorageLevel.MEMORY_AND_DISK)
+ assert(transform(StorageLevel.MEMORY_ONLY_SER) === StorageLevel.MEMORY_AND_DISK_SER)
+ assert(transform(StorageLevel.MEMORY_ONLY_2) === StorageLevel.MEMORY_AND_DISK_2)
+ assert(transform(StorageLevel.MEMORY_ONLY_SER_2) === StorageLevel.MEMORY_AND_DISK_SER_2)
+ assert(transform(StorageLevel.DISK_ONLY) === StorageLevel.DISK_ONLY)
+ assert(transform(StorageLevel.DISK_ONLY_2) === StorageLevel.DISK_ONLY_2)
+ assert(transform(StorageLevel.MEMORY_AND_DISK) === StorageLevel.MEMORY_AND_DISK)
+ assert(transform(StorageLevel.MEMORY_AND_DISK_SER) === StorageLevel.MEMORY_AND_DISK_SER)
+ assert(transform(StorageLevel.MEMORY_AND_DISK_2) === StorageLevel.MEMORY_AND_DISK_2)
+ assert(transform(StorageLevel.MEMORY_AND_DISK_SER_2) === StorageLevel.MEMORY_AND_DISK_SER_2)
+ // Off-heap is not supported and Spark should fail fast
+ intercept[SparkException] {
+ transform(StorageLevel.OFF_HEAP)
+ }
+ }
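A sketch of the user-visible effect of this transform (assumes an existing SparkContext sc): a memory-only cache level is upgraded to also use disk, so checkpointed blocks cannot be silently evicted from memory.

    import org.apache.spark.storage.StorageLevel

    val rdd = sc.parallelize(1 to 100, 4).persist(StorageLevel.MEMORY_ONLY)
    rdd.localCheckpoint()
    rdd.count() // materializes the local checkpoint
    assert(rdd.getStorageLevel == StorageLevel.MEMORY_AND_DISK) // memory-only upgraded to include disk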
+
+ test("basic lineage truncation") {
+ val numPartitions = 4
+ val parallelRdd = sc.parallelize(1 to 100, numPartitions)
+ val mappedRdd = parallelRdd.map { i => i + 1 }
+ val filteredRdd = mappedRdd.filter { i => i % 2 == 0 }
+ val expectedPartitionIndices = (0 until numPartitions).toArray
+ assert(filteredRdd.checkpointData.isEmpty)
+ assert(filteredRdd.getStorageLevel === StorageLevel.NONE)
+ assert(filteredRdd.partitions.map(_.index) === expectedPartitionIndices)
+ assert(filteredRdd.dependencies.size === 1)
+ assert(filteredRdd.dependencies.head.rdd === mappedRdd)
+ assert(mappedRdd.dependencies.size === 1)
+ assert(mappedRdd.dependencies.head.rdd === parallelRdd)
+ assert(parallelRdd.dependencies.size === 0)
+
+ // Mark the RDD for local checkpointing
+ filteredRdd.localCheckpoint()
+ assert(filteredRdd.checkpointData.isDefined)
+ assert(!filteredRdd.checkpointData.get.isCheckpointed)
+ assert(!filteredRdd.checkpointData.get.checkpointRDD.isDefined)
+ assert(filteredRdd.getStorageLevel === LocalRDDCheckpointData.DEFAULT_STORAGE_LEVEL)
+
+ // After an action, the lineage is truncated
+ val result = filteredRdd.collect()
+ assert(filteredRdd.checkpointData.get.isCheckpointed)
+ assert(filteredRdd.checkpointData.get.checkpointRDD.isDefined)
+ val checkpointRdd = filteredRdd.checkpointData.flatMap(_.checkpointRDD).get
+ assert(filteredRdd.dependencies.size === 1)
+ assert(filteredRdd.dependencies.head.rdd === checkpointRdd)
+ assert(filteredRdd.partitions.map(_.index) === expectedPartitionIndices)
+ assert(checkpointRdd.partitions.map(_.index) === expectedPartitionIndices)
+
+ // Recomputation should yield the same result
+ assert(filteredRdd.collect() === result)
+ assert(filteredRdd.collect() === result)
+ }
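The same truncation is observable from user code; a sketch assuming an existing SparkContext sc:

    val rdd = sc.parallelize(1 to 100, 4).map(_ + 1).filter(_ % 2 == 0)
    rdd.localCheckpoint()
    assert(!rdd.isCheckpointed)        // nothing is materialized before the first action
    rdd.count()
    assert(rdd.isCheckpointed)         // the lineage now ends at the checkpoint RDD
    assert(rdd.dependencies.size == 1) // a single dependency on the checkpointed data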
+
+ test("basic lineage truncation - caching before checkpointing") {
+ testBasicLineageTruncationWithCaching(
+ newRdd.persist(StorageLevel.MEMORY_ONLY).localCheckpoint(),
+ StorageLevel.MEMORY_AND_DISK)
+ }
+
+ test("basic lineage truncation - caching after checkpointing") {
+ testBasicLineageTruncationWithCaching(
+ newRdd.localCheckpoint().persist(StorageLevel.MEMORY_ONLY),
+ StorageLevel.MEMORY_AND_DISK)
+ }
+
+ test("indirect lineage truncation") {
+ testIndirectLineageTruncation(
+ newRdd.localCheckpoint(),
+ LocalRDDCheckpointData.DEFAULT_STORAGE_LEVEL)
+ }
+
+ test("indirect lineage truncation - caching before checkpointing") {
+ testIndirectLineageTruncation(
+ newRdd.persist(StorageLevel.MEMORY_ONLY).localCheckpoint(),
+ StorageLevel.MEMORY_AND_DISK)
+ }
+
+ test("indirect lineage truncation - caching after checkpointing") {
+ testIndirectLineageTruncation(
+ newRdd.localCheckpoint().persist(StorageLevel.MEMORY_ONLY),
+ StorageLevel.MEMORY_AND_DISK)
+ }
+
+ test("checkpoint without draining iterator") {
+ testWithoutDrainingIterator(
+ newSortedRdd.localCheckpoint(),
+ LocalRDDCheckpointData.DEFAULT_STORAGE_LEVEL,
+ 50)
+ }
+
+ test("checkpoint without draining iterator - caching before checkpointing") {
+ testWithoutDrainingIterator(
+ newSortedRdd.persist(StorageLevel.MEMORY_ONLY).localCheckpoint(),
+ StorageLevel.MEMORY_AND_DISK,
+ 50)
+ }
+
+ test("checkpoint without draining iterator - caching after checkpointing") {
+ testWithoutDrainingIterator(
+ newSortedRdd.localCheckpoint().persist(StorageLevel.MEMORY_ONLY),
+ StorageLevel.MEMORY_AND_DISK,
+ 50)
+ }
+
+ test("checkpoint blocks exist") {
+ testCheckpointBlocksExist(
+ newRdd.localCheckpoint(),
+ LocalRDDCheckpointData.DEFAULT_STORAGE_LEVEL)
+ }
+
+ test("checkpoint blocks exist - caching before checkpointing") {
+ testCheckpointBlocksExist(
+ newRdd.persist(StorageLevel.MEMORY_ONLY).localCheckpoint(),
+ StorageLevel.MEMORY_AND_DISK)
+ }
+
+ test("checkpoint blocks exist - caching after checkpointing") {
+ testCheckpointBlocksExist(
+ newRdd.localCheckpoint().persist(StorageLevel.MEMORY_ONLY),
+ StorageLevel.MEMORY_AND_DISK)
+ }
+
+ test("missing checkpoint block fails with informative message") {
+ val rdd = newRdd.localCheckpoint()
+ val numPartitions = rdd.partitions.size
+ val partitionIndices = rdd.partitions.map(_.index)
+ val bmm = sc.env.blockManager.master
+
+ // After an action, the blocks should be found somewhere in the cache
+ rdd.collect()
+ partitionIndices.foreach { i =>
+ assert(bmm.contains(RDDBlockId(rdd.id, i)))
+ }
+
+ // Remove one of the blocks to simulate executor failure
+ // Collecting the RDD should now fail with an informative exception
+ val blockId = RDDBlockId(rdd.id, numPartitions - 1)
+ bmm.removeBlock(blockId)
+ try {
+ rdd.collect()
+ fail("Collect should have failed if local checkpoint block is removed...")
+ } catch {
+ case se: SparkException =>
+ assert(se.getMessage.contains(s"Checkpoint block $blockId not found"))
+ assert(se.getMessage.contains("rdd.checkpoint()")) // suggest an alternative
+ assert(se.getMessage.contains("fault-tolerant")) // justify the alternative
+ }
+ }
+
+ /**
+ * Helper method to create a simple RDD.
+ */
+ private def newRdd: RDD[Int] = {
+ sc.parallelize(1 to 100, 4)
+ .map { i => i + 1 }
+ .filter { i => i % 2 == 0 }
+ }
+
+ /**
+ * Helper method to create a simple sorted RDD.
+ */
+ private def newSortedRdd: RDD[Int] = newRdd.sortBy(identity)
+
+ /**
+ * Helper method to test basic lineage truncation with caching.
+ *
+ * @param rdd an RDD that is both marked for caching and local checkpointing
+ */
+ private def testBasicLineageTruncationWithCaching[T](
+ rdd: RDD[T],
+ targetStorageLevel: StorageLevel): Unit = {
+ require(targetStorageLevel !== StorageLevel.NONE)
+ require(rdd.getStorageLevel !== StorageLevel.NONE)
+ require(rdd.isLocallyCheckpointed)
+ val result = rdd.collect()
+ assert(rdd.getStorageLevel === targetStorageLevel)
+ assert(rdd.checkpointData.isDefined)
+ assert(rdd.checkpointData.get.isCheckpointed)
+ assert(rdd.checkpointData.get.checkpointRDD.isDefined)
+ assert(rdd.dependencies.head.rdd === rdd.checkpointData.get.checkpointRDD.get)
+ assert(rdd.collect() === result)
+ assert(rdd.collect() === result)
+ }
+
+ /**
+ * Helper method to test indirect lineage truncation.
+ *
+ * Indirect lineage truncation here means the action is called on one of the
+ * checkpointed RDD's descendants, but not on the checkpointed RDD itself.
+ *
+ * @param rdd a locally checkpointed RDD
+ */
+ private def testIndirectLineageTruncation[T](
+ rdd: RDD[T],
+ targetStorageLevel: StorageLevel): Unit = {
+ require(targetStorageLevel !== StorageLevel.NONE)
+ require(rdd.isLocallyCheckpointed)
+ val rdd1 = rdd.map { i => i + "1" }
+ val rdd2 = rdd1.map { i => i + "2" }
+ val rdd3 = rdd2.map { i => i + "3" }
+ val rddDependencies = rdd.dependencies
+ val rdd1Dependencies = rdd1.dependencies
+ val rdd2Dependencies = rdd2.dependencies
+ val rdd3Dependencies = rdd3.dependencies
+ assert(rdd1Dependencies.size === 1)
+ assert(rdd1Dependencies.head.rdd === rdd)
+ assert(rdd2Dependencies.size === 1)
+ assert(rdd2Dependencies.head.rdd === rdd1)
+ assert(rdd3Dependencies.size === 1)
+ assert(rdd3Dependencies.head.rdd === rdd2)
+
+ // Only the locally checkpointed RDD should have special storage level
+ assert(rdd.getStorageLevel === targetStorageLevel)
+ assert(rdd1.getStorageLevel === StorageLevel.NONE)
+ assert(rdd2.getStorageLevel === StorageLevel.NONE)
+ assert(rdd3.getStorageLevel === StorageLevel.NONE)
+
+ // After an action, only the dependencies of the checkpointed RDD change
+ val result = rdd3.collect()
+ assert(rdd.dependencies !== rddDependencies)
+ assert(rdd1.dependencies === rdd1Dependencies)
+ assert(rdd2.dependencies === rdd2Dependencies)
+ assert(rdd3.dependencies === rdd3Dependencies)
+ assert(rdd3.collect() === result)
+ assert(rdd3.collect() === result)
+ }
+
+ /**
+ * Helper method to test checkpointing without fully draining the iterator.
+ *
+ * Not all RDD actions fully consume the iterator. As a result, a subset of the partitions
+ * may not be cached. However, since we want to truncate the lineage safely, we explicitly
+ * ensure that *all* partitions are fully cached. This method asserts this behavior.
+ *
+ * @param rdd a locally checkpointed RDD
+ */
+ private def testWithoutDrainingIterator[T](
+ rdd: RDD[T],
+ targetStorageLevel: StorageLevel,
+ targetCount: Int): Unit = {
+ require(targetCount > 0)
+ require(targetStorageLevel !== StorageLevel.NONE)
+ require(rdd.isLocallyCheckpointed)
+
+ // This does not drain the iterator, but checkpointing should still work
+ val first = rdd.first()
+ assert(rdd.count() === targetCount)
+ assert(rdd.count() === targetCount)
+ assert(rdd.first() === first)
+ assert(rdd.first() === first)
+
+ // Test the same thing by calling actions on a descendant instead
+ val rdd1 = rdd.repartition(10)
+ val rdd2 = rdd1.repartition(100)
+ val rdd3 = rdd2.repartition(1000)
+ val first2 = rdd3.first()
+ assert(rdd3.count() === targetCount)
+ assert(rdd3.count() === targetCount)
+ assert(rdd3.first() === first2)
+ assert(rdd3.first() === first2)
+ assert(rdd.getStorageLevel === targetStorageLevel)
+ assert(rdd1.getStorageLevel === StorageLevel.NONE)
+ assert(rdd2.getStorageLevel === StorageLevel.NONE)
+ assert(rdd3.getStorageLevel === StorageLevel.NONE)
+ }
+
+ /**
+ * Helper method to test whether the checkpoint blocks are found in the cache.
+ *
+ * @param rdd a locally checkpointed RDD
+ */
+ private def testCheckpointBlocksExist[T](
+ rdd: RDD[T],
+ targetStorageLevel: StorageLevel): Unit = {
+ val bmm = sc.env.blockManager.master
+ val partitionIndices = rdd.partitions.map(_.index)
+
+ // The blocks should not exist before the action
+ partitionIndices.foreach { i =>
+ assert(!bmm.contains(RDDBlockId(rdd.id, i)))
+ }
+
+ // After an action, the blocks should be found in the cache with the expected level
+ rdd.collect()
+ partitionIndices.foreach { i =>
+ val blockId = RDDBlockId(rdd.id, i)
+ val status = bmm.getBlockStatus(blockId)
+ assert(status.nonEmpty)
+ assert(status.values.head.storageLevel === targetStorageLevel)
+ }
+ }
+
+}
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index f9384c4c3c..280aac9319 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -80,8 +80,13 @@ object MimaExcludes {
"org.apache.spark.mllib.linalg.Matrix.numActives")
) ++ Seq(
// SPARK-8914 Remove RDDApi
- ProblemFilters.exclude[MissingClassProblem](
- "org.apache.spark.sql.RDDApi")
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.RDDApi")
+ ) ++ Seq(
+ // SPARK-7292 Provide operator to truncate lineage cheaply
+ ProblemFilters.exclude[AbstractClassProblem](
+ "org.apache.spark.rdd.RDDCheckpointData"),
+ ProblemFilters.exclude[AbstractClassProblem](
+ "org.apache.spark.rdd.CheckpointRDD")
) ++ Seq(
// SPARK-8701 Add input metadata in the batch page.
ProblemFilters.exclude[MissingClassProblem](