-rw-r--r--  core/src/main/scala/org/apache/spark/ContextCleaner.scala                |   9
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala                  |   2
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContext.scala                   |   8
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala             | 153
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala        |  67
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala    |  83
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala                       | 128
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala         | 106
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala     | 172
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala | 108
-rw-r--r--  core/src/test/scala/org/apache/spark/CheckpointSuite.scala               | 164
-rw-r--r--  core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala           |  61
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala      | 330
-rw-r--r--  project/MimaExcludes.scala                                               |   9
14 files changed, 1085 insertions(+), 315 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
index 37198d887b..d23c1533db 100644
--- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -22,7 +22,7 @@ import java.lang.ref.{ReferenceQueue, WeakReference}
import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.rdd.{RDDCheckpointData, RDD}
+import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData}
import org.apache.spark.util.Utils
/**
@@ -231,11 +231,14 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
}
}
- /** Perform checkpoint cleanup. */
+ /**
+ * Clean up checkpoint files written to a reliable storage.
+ * Locally checkpointed files are cleaned up separately through RDD cleanups.
+ */
def doCleanCheckpoint(rddId: Int): Unit = {
try {
logDebug("Cleaning rdd checkpoint data " + rddId)
- RDDCheckpointData.clearRDDCheckpointData(sc, rddId)
+ ReliableRDDCheckpointData.cleanCheckpoint(sc, rddId)
listeners.foreach(_.checkpointCleaned(rddId))
logInfo("Cleaned rdd checkpoint data " + rddId)
}
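
For context, a minimal driver-side sketch of how this cleanup path is exercised, assuming the `spark.cleaner.referenceTracking.cleanCheckpoints` flag referenced later in this patch (master and paths are illustrative):

    import org.apache.spark.{SparkConf, SparkContext}

    val conf = new SparkConf()
      .setMaster("local[2]")                 // illustrative master
      .setAppName("checkpoint-cleanup-sketch")
      .set("spark.cleaner.referenceTracking.cleanCheckpoints", "true")
    val sc = new SparkContext(conf)
    sc.setCheckpointDir("/tmp/checkpoints")  // hypothetical directory
    val rdd = sc.parallelize(1 to 100).map(_ + 1)
    rdd.checkpoint()
    rdd.count()  // materializes the checkpoint files under the checkpoint dir
    // Once `rdd` is garbage collected, the ContextCleaner calls doCleanCheckpoint(rdd.id),
    // which now delegates to ReliableRDDCheckpointData.cleanCheckpoint to delete those files.
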
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 6f336a7c29..4380cf45cc 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1192,7 +1192,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}
protected[spark] def checkpointFile[T: ClassTag](path: String): RDD[T] = withScope {
- new CheckpointRDD[T](this, path)
+ new ReliableCheckpointRDD[T](this, path)
}
/** Build the union of a list of RDDs. */
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index b48836d5c8..5d2c551d58 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -59,6 +59,14 @@ object TaskContext {
* Unset the thread local TaskContext. Internal to Spark.
*/
protected[spark] def unset(): Unit = taskContext.remove()
+
+ /**
+ * Return an empty task context that is not actually used.
+ * Internal use only.
+ */
+ private[spark] def empty(): TaskContext = {
+ new TaskContextImpl(0, 0, 0, 0, null, null)
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
index e17bd47905..72fe215dae 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -17,156 +17,31 @@
package org.apache.spark.rdd
-import java.io.IOException
-
import scala.reflect.ClassTag
-import org.apache.hadoop.fs.Path
-
-import org.apache.spark._
-import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.util.{SerializableConfiguration, Utils}
+import org.apache.spark.{Partition, SparkContext, TaskContext}
+/**
+ * An RDD partition used to recover checkpointed data.
+ */
private[spark] class CheckpointRDDPartition(val index: Int) extends Partition
/**
- * This RDD represents a RDD checkpoint file (similar to HadoopRDD).
+ * An RDD that recovers checkpointed data from storage.
*/
-private[spark]
-class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String)
+private[spark] abstract class CheckpointRDD[T: ClassTag](@transient sc: SparkContext)
extends RDD[T](sc, Nil) {
- private val broadcastedConf = sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration))
-
- @transient private val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)
-
- override def getCheckpointFile: Option[String] = Some(checkpointPath)
-
- override def getPartitions: Array[Partition] = {
- val cpath = new Path(checkpointPath)
- val numPartitions =
- // listStatus can throw exception if path does not exist.
- if (fs.exists(cpath)) {
- val dirContents = fs.listStatus(cpath).map(_.getPath)
- val partitionFiles = dirContents.filter(_.getName.startsWith("part-")).map(_.toString).sorted
- val numPart = partitionFiles.length
- if (numPart > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) ||
- ! partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) {
- throw new SparkException("Invalid checkpoint directory: " + checkpointPath)
- }
- numPart
- } else 0
-
- Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i))
- }
-
- override def getPreferredLocations(split: Partition): Seq[String] = {
- val status = fs.getFileStatus(new Path(checkpointPath,
- CheckpointRDD.splitIdToFile(split.index)))
- val locations = fs.getFileBlockLocations(status, 0, status.getLen)
- locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
- }
-
- override def compute(split: Partition, context: TaskContext): Iterator[T] = {
- val file = new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index))
- CheckpointRDD.readFromFile(file, broadcastedConf, context)
- }
-
// CheckpointRDD should not be checkpointed again
- override def checkpoint(): Unit = { }
override def doCheckpoint(): Unit = { }
-}
-
-private[spark] object CheckpointRDD extends Logging {
- def splitIdToFile(splitId: Int): String = {
- "part-%05d".format(splitId)
- }
-
- def writeToFile[T: ClassTag](
- path: String,
- broadcastedConf: Broadcast[SerializableConfiguration],
- blockSize: Int = -1
- )(ctx: TaskContext, iterator: Iterator[T]) {
- val env = SparkEnv.get
- val outputDir = new Path(path)
- val fs = outputDir.getFileSystem(broadcastedConf.value.value)
-
- val finalOutputName = splitIdToFile(ctx.partitionId)
- val finalOutputPath = new Path(outputDir, finalOutputName)
- val tempOutputPath =
- new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptNumber)
-
- if (fs.exists(tempOutputPath)) {
- throw new IOException("Checkpoint failed: temporary path " +
- tempOutputPath + " already exists")
- }
- val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
-
- val fileOutputStream = if (blockSize < 0) {
- fs.create(tempOutputPath, false, bufferSize)
- } else {
- // This is mainly for testing purpose
- fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize)
- }
- val serializer = env.serializer.newInstance()
- val serializeStream = serializer.serializeStream(fileOutputStream)
- Utils.tryWithSafeFinally {
- serializeStream.writeAll(iterator)
- } {
- serializeStream.close()
- }
-
- if (!fs.rename(tempOutputPath, finalOutputPath)) {
- if (!fs.exists(finalOutputPath)) {
- logInfo("Deleting tempOutputPath " + tempOutputPath)
- fs.delete(tempOutputPath, false)
- throw new IOException("Checkpoint failed: failed to save output of task: "
- + ctx.attemptNumber + " and final output path does not exist")
- } else {
- // Some other copy of this task must've finished before us and renamed it
- logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it")
- fs.delete(tempOutputPath, false)
- }
- }
- }
-
- def readFromFile[T](
- path: Path,
- broadcastedConf: Broadcast[SerializableConfiguration],
- context: TaskContext
- ): Iterator[T] = {
- val env = SparkEnv.get
- val fs = path.getFileSystem(broadcastedConf.value.value)
- val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
- val fileInputStream = fs.open(path, bufferSize)
- val serializer = env.serializer.newInstance()
- val deserializeStream = serializer.deserializeStream(fileInputStream)
-
- // Register an on-task-completion callback to close the input stream.
- context.addTaskCompletionListener(context => deserializeStream.close())
-
- deserializeStream.asIterator.asInstanceOf[Iterator[T]]
- }
+ override def checkpoint(): Unit = { }
+ override def localCheckpoint(): this.type = this
- // Test whether CheckpointRDD generate expected number of partitions despite
- // each split file having multiple blocks. This needs to be run on a
- // cluster (mesos or standalone) using HDFS.
- def main(args: Array[String]) {
- import org.apache.spark._
+ // Note: There is a bug in MiMa that complains about `AbstractMethodProblem`s in the
+ // base [[org.apache.spark.rdd.RDD]] class if we do not override the following methods.
+ // scalastyle:off
+ protected override def getPartitions: Array[Partition] = ???
+ override def compute(p: Partition, tc: TaskContext): Iterator[T] = ???
+ // scalastyle:on
- val Array(cluster, hdfsPath) = args
- val env = SparkEnv.get
- val sc = new SparkContext(cluster, "CheckpointRDD Test")
- val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000)
- val path = new Path(hdfsPath, "temp")
- val conf = SparkHadoopUtil.get.newConfiguration(new SparkConf())
- val fs = path.getFileSystem(conf)
- val broadcastedConf = sc.broadcast(new SerializableConfiguration(conf))
- sc.runJob(rdd, CheckpointRDD.writeToFile[Int](path.toString, broadcastedConf, 1024) _)
- val cpRDD = new CheckpointRDD[Int](sc, path.toString)
- assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same")
- assert(cpRDD.collect.toList == rdd.collect.toList, "Data of partitions not the same")
- fs.delete(path, true)
- }
}
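
With CheckpointRDD now abstract, the concrete read-back logic lives in the ReliableCheckpointRDD and LocalCheckpointRDD subclasses added below. A small sketch of what this looks like from the driver side (assumes an existing SparkContext `sc` with a checkpoint directory set):

    val rdd = sc.parallelize(1 to 10, 2)
    rdd.checkpoint()
    rdd.count()
    // After materialization the RDD's only dependency is a CheckpointRDD subclass:
    println(rdd.dependencies.head.rdd.getClass.getSimpleName)  // "ReliableCheckpointRDD"
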
diff --git a/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala
new file mode 100644
index 0000000000..daa5779d68
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.{Partition, SparkContext, SparkEnv, SparkException, TaskContext}
+import org.apache.spark.storage.RDDBlockId
+
+/**
+ * A dummy CheckpointRDD that exists to provide informative error messages during failures.
+ *
+ * This is simply a placeholder because the original checkpointed RDD is expected to be
+ * fully cached. Only if an executor fails or if the user explicitly unpersists the original
+ * RDD will Spark ever attempt to compute this CheckpointRDD. When this happens, however,
+ * we must provide an informative error message.
+ *
+ * @param sc the active SparkContext
+ * @param rddId the ID of the checkpointed RDD
+ * @param numPartitions the number of partitions in the checkpointed RDD
+ */
+private[spark] class LocalCheckpointRDD[T: ClassTag](
+ @transient sc: SparkContext,
+ rddId: Int,
+ numPartitions: Int)
+ extends CheckpointRDD[T](sc) {
+
+ def this(rdd: RDD[T]) {
+ this(rdd.context, rdd.id, rdd.partitions.size)
+ }
+
+ protected override def getPartitions: Array[Partition] = {
+ (0 until numPartitions).toArray.map { i => new CheckpointRDDPartition(i) }
+ }
+
+ /**
+ * Throw an exception indicating that the relevant block is not found.
+ *
+ * This should only be called if the original RDD is explicitly unpersisted or if an
+ * executor is lost. Under normal circumstances, however, the original RDD (our child)
+ * is expected to be fully cached and so all partitions should already be computed and
+ * available in the block storage.
+ */
+ override def compute(partition: Partition, context: TaskContext): Iterator[T] = {
+ throw new SparkException(
+ s"Checkpoint block ${RDDBlockId(rddId, partition.index)} not found! Either the executor " +
+ s"that originally checkpointed this partition is no longer alive, or the original RDD is " +
+ s"unpersisted. If this problem persists, you may consider using `rdd.checkpoint()` " +
+ s"instead, which is slower than local checkpointing but more fault-tolerant.")
+ }
+
+}
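
A sketch of the failure mode this placeholder guards against (assumes an existing SparkContext `sc`); under normal operation its compute() is never reached because the original RDD remains cached:

    val rdd = sc.parallelize(1 to 100, 2).map(_ * 2)
    rdd.localCheckpoint()
    rdd.count()                      // caches the partitions and truncates the lineage
    rdd.unpersist(blocking = true)   // drops the cached blocks backing the checkpoint
    rdd.count()                      // reaches LocalCheckpointRDD.compute, which throws the SparkException above
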
diff --git a/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala
new file mode 100644
index 0000000000..d6fad89684
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.{Logging, SparkEnv, SparkException, TaskContext}
+import org.apache.spark.storage.{RDDBlockId, StorageLevel}
+import org.apache.spark.util.Utils
+
+/**
+ * An implementation of checkpointing implemented on top of Spark's caching layer.
+ *
+ * Local checkpointing trades off fault tolerance for performance by skipping the expensive
+ * step of saving the RDD data to a reliable and fault-tolerant storage. Instead, the data
+ * is written to the local, ephemeral block storage that lives in each executor. This is useful
+ * for use cases where RDDs build up long lineages that need to be truncated often (e.g. GraphX).
+ */
+private[spark] class LocalRDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
+ extends RDDCheckpointData[T](rdd) with Logging {
+
+ /**
+ * Ensure the RDD is fully cached so the partitions can be recovered later.
+ */
+ protected override def doCheckpoint(): CheckpointRDD[T] = {
+ val level = rdd.getStorageLevel
+
+ // Assume storage level uses disk; otherwise memory eviction may cause data loss
+ assume(level.useDisk, s"Storage level $level is not appropriate for local checkpointing")
+
+ // Not all actions compute all partitions of the RDD (e.g. take). For correctness, we
+ // must cache any missing partitions. TODO: avoid running another job here (SPARK-8582).
+ val action = (tc: TaskContext, iterator: Iterator[T]) => Utils.getIteratorSize(iterator)
+ val missingPartitionIndices = rdd.partitions.map(_.index).filter { i =>
+ !SparkEnv.get.blockManager.master.contains(RDDBlockId(rdd.id, i))
+ }
+ if (missingPartitionIndices.nonEmpty) {
+ rdd.sparkContext.runJob(rdd, action, missingPartitionIndices)
+ }
+
+ new LocalCheckpointRDD[T](rdd)
+ }
+
+}
+
+private[spark] object LocalRDDCheckpointData {
+
+ val DEFAULT_STORAGE_LEVEL = StorageLevel.MEMORY_AND_DISK
+
+ /**
+ * Transform the specified storage level to one that uses disk.
+ *
+ * This guarantees that the RDD can be recomputed multiple times correctly as long as
+ * executors do not fail. Otherwise, if the RDD is cached in memory only, for instance,
+ * the checkpoint data will be lost if the relevant block is evicted from memory.
+ *
+ * This method is idempotent.
+ */
+ def transformStorageLevel(level: StorageLevel): StorageLevel = {
+ // If this RDD is to be cached off-heap, fail fast since we cannot provide any
+ // correctness guarantees about subsequent computations after the first one
+ if (level.useOffHeap) {
+ throw new SparkException("Local checkpointing is not compatible with off-heap caching.")
+ }
+
+ StorageLevel(useDisk = true, level.useMemory, level.deserialized, level.replication)
+ }
+}
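
An illustration of the transformStorageLevel mapping using standard StorageLevel constants (LocalRDDCheckpointData is private[spark]; this is shown only to document the mapping implied by the constructor call above):

    LocalRDDCheckpointData.transformStorageLevel(StorageLevel.MEMORY_ONLY)      // MEMORY_AND_DISK
    LocalRDDCheckpointData.transformStorageLevel(StorageLevel.MEMORY_ONLY_SER)  // MEMORY_AND_DISK_SER
    LocalRDDCheckpointData.transformStorageLevel(StorageLevel.DISK_ONLY)        // DISK_ONLY (idempotent)
    LocalRDDCheckpointData.transformStorageLevel(StorageLevel.OFF_HEAP)         // throws SparkException
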
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 6d61d22738..081c721f23 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -149,23 +149,43 @@ abstract class RDD[T: ClassTag](
}
/**
- * Set this RDD's storage level to persist its values across operations after the first time
- * it is computed. This can only be used to assign a new storage level if the RDD does not
- * have a storage level set yet..
+ * Mark this RDD for persisting using the specified level.
+ *
+ * @param newLevel the target storage level
+ * @param allowOverride whether to override any existing level with the new one
*/
- def persist(newLevel: StorageLevel): this.type = {
+ private def persist(newLevel: StorageLevel, allowOverride: Boolean): this.type = {
// TODO: Handle changes of StorageLevel
- if (storageLevel != StorageLevel.NONE && newLevel != storageLevel) {
+ if (storageLevel != StorageLevel.NONE && newLevel != storageLevel && !allowOverride) {
throw new UnsupportedOperationException(
"Cannot change storage level of an RDD after it was already assigned a level")
}
- sc.persistRDD(this)
- // Register the RDD with the ContextCleaner for automatic GC-based cleanup
- sc.cleaner.foreach(_.registerRDDForCleanup(this))
+ // If this is the first time this RDD is marked for persisting, register it
+ // with the SparkContext for cleanups and accounting. Do this only once.
+ if (storageLevel == StorageLevel.NONE) {
+ sc.cleaner.foreach(_.registerRDDForCleanup(this))
+ sc.persistRDD(this)
+ }
storageLevel = newLevel
this
}
+ /**
+ * Set this RDD's storage level to persist its values across operations after the first time
+ * it is computed. This can only be used to assign a new storage level if the RDD does not
+ * have a storage level set yet. Local checkpointing is an exception.
+ */
+ def persist(newLevel: StorageLevel): this.type = {
+ if (isLocallyCheckpointed) {
+ // This means the user previously called localCheckpoint(), which should have already
+ // marked this RDD for persisting. Here we should override the old storage level with
+ // one that is explicitly requested by the user (after adapting it to use disk).
+ persist(LocalRDDCheckpointData.transformStorageLevel(newLevel), allowOverride = true)
+ } else {
+ persist(newLevel, allowOverride = false)
+ }
+ }
+
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def persist(): this.type = persist(StorageLevel.MEMORY_ONLY)
@@ -1448,33 +1468,99 @@ abstract class RDD[T: ClassTag](
/**
* Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint
- * directory set with SparkContext.setCheckpointDir() and all references to its parent
+ * directory set with `SparkContext#setCheckpointDir` and all references to its parent
* RDDs will be removed. This function must be called before any job has been
* executed on this RDD. It is strongly recommended that this RDD is persisted in
* memory, otherwise saving it on a file will require recomputation.
*/
- def checkpoint(): Unit = {
+ def checkpoint(): Unit = RDDCheckpointData.synchronized {
+ // NOTE: we use a global lock here due to complexities downstream with ensuring
+ // children RDD partitions point to the correct parent partitions. In the future
+ // we should revisit this consideration.
if (context.checkpointDir.isEmpty) {
throw new SparkException("Checkpoint directory has not been set in the SparkContext")
} else if (checkpointData.isEmpty) {
- // NOTE: we use a global lock here due to complexities downstream with ensuring
- // children RDD partitions point to the correct parent partitions. In the future
- // we should revisit this consideration.
- RDDCheckpointData.synchronized {
- checkpointData = Some(new RDDCheckpointData(this))
- }
+ checkpointData = Some(new ReliableRDDCheckpointData(this))
+ }
+ }
+
+ /**
+ * Mark this RDD for local checkpointing using Spark's existing caching layer.
+ *
+ * This method is for users who wish to truncate RDD lineages while skipping the expensive
+ * step of replicating the materialized data in a reliable distributed file system. This is
+ * useful for RDDs with long lineages that need to be truncated periodically (e.g. GraphX).
+ *
+ * Local checkpointing sacrifices fault-tolerance for performance. In particular, checkpointed
+ * data is written to ephemeral local storage in the executors instead of to a reliable,
+ * fault-tolerant storage. The effect is that if an executor fails during the computation,
+ * the checkpointed data may no longer be accessible, causing an irrecoverable job failure.
+ *
+ * This is NOT safe to use with dynamic allocation, which removes executors along
+ * with their cached blocks. If you must use both features, you are advised to set
+ * `spark.dynamicAllocation.cachedExecutorIdleTimeout` to a high value.
+ *
+ * The checkpoint directory set through `SparkContext#setCheckpointDir` is not used.
+ */
+ def localCheckpoint(): this.type = RDDCheckpointData.synchronized {
+ if (conf.getBoolean("spark.dynamicAllocation.enabled", false) &&
+ conf.contains("spark.dynamicAllocation.cachedExecutorIdleTimeout")) {
+ logWarning("Local checkpointing is NOT safe to use with dynamic allocation, " +
+ "which removes executors along with their cached blocks. If you must use both " +
+ "features, you are advised to set `spark.dynamicAllocation.cachedExecutorIdleTimeout` " +
+ "to a high value. E.g. If you plan to use the RDD for 1 hour, set the timeout to " +
+ "at least 1 hour.")
+ }
+
+ // Note: At this point we do not actually know whether the user will call persist() on
+ // this RDD later, so we must explicitly call it here ourselves to ensure the cached
+ // blocks are registered for cleanup later in the SparkContext.
+ //
+ // If, however, the user has already called persist() on this RDD, then we must adapt
+ // the storage level he/she specified to one that is appropriate for local checkpointing
+ // (i.e. uses disk) to guarantee correctness.
+
+ if (storageLevel == StorageLevel.NONE) {
+ persist(LocalRDDCheckpointData.DEFAULT_STORAGE_LEVEL)
+ } else {
+ persist(LocalRDDCheckpointData.transformStorageLevel(storageLevel), allowOverride = true)
}
+
+ checkpointData match {
+ case Some(reliable: ReliableRDDCheckpointData[_]) => logWarning(
+ "RDD was already marked for reliable checkpointing: overriding with local checkpoint.")
+ case _ =>
+ }
+ checkpointData = Some(new LocalRDDCheckpointData(this))
+ this
}
/**
- * Return whether this RDD has been checkpointed or not
+ * Return whether this RDD is marked for checkpointing, either reliably or locally.
*/
def isCheckpointed: Boolean = checkpointData.exists(_.isCheckpointed)
/**
- * Gets the name of the file to which this RDD was checkpointed
+ * Return whether this RDD is marked for local checkpointing.
+ * Exposed for testing.
*/
- def getCheckpointFile: Option[String] = checkpointData.flatMap(_.getCheckpointFile)
+ private[rdd] def isLocallyCheckpointed: Boolean = {
+ checkpointData match {
+ case Some(_: LocalRDDCheckpointData[T]) => true
+ case _ => false
+ }
+ }
+
+ /**
+ * Gets the name of the directory to which this RDD was checkpointed.
+ * This is not defined if the RDD is checkpointed locally.
+ */
+ def getCheckpointFile: Option[String] = {
+ checkpointData match {
+ case Some(reliable: ReliableRDDCheckpointData[T]) => reliable.getCheckpointDir
+ case _ => None
+ }
+ }
// =======================================================================
// Other internal methods and fields
@@ -1545,7 +1631,7 @@ abstract class RDD[T: ClassTag](
if (!doCheckpointCalled) {
doCheckpointCalled = true
if (checkpointData.isDefined) {
- checkpointData.get.doCheckpoint()
+ checkpointData.get.checkpoint()
} else {
dependencies.foreach(_.rdd.doCheckpoint())
}
@@ -1557,7 +1643,7 @@ abstract class RDD[T: ClassTag](
* Changes the dependencies of this RDD from its original parents to a new RDD (`newRDD`)
* created from the checkpoint file, and forget its old dependencies and partitions.
*/
- private[spark] def markCheckpointed(checkpointRDD: RDD[_]) {
+ private[spark] def markCheckpointed(): Unit = {
clearDependencies()
partitions_ = null
deps = null // Forget the constructor argument for dependencies too
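
A minimal usage sketch of the new localCheckpoint() API described above (assumes an existing SparkContext `sc`):

    val rdd = sc.parallelize(1 to 1000, 4).map(_ * 2)
    rdd.localCheckpoint()                  // persists with MEMORY_AND_DISK and marks for local checkpointing
    rdd.count()                            // the first action materializes the checkpoint and truncates the lineage
    assert(rdd.isCheckpointed)
    assert(rdd.getCheckpointFile.isEmpty)  // no directory: the data lives in the executors' block store
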
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
index 4f954363be..0e43520870 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -19,10 +19,7 @@ package org.apache.spark.rdd
import scala.reflect.ClassTag
-import org.apache.hadoop.fs.Path
-
-import org.apache.spark._
-import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.Partition
/**
* Enumeration to manage state transitions of an RDD through checkpointing
@@ -39,39 +36,31 @@ private[spark] object CheckpointState extends Enumeration {
* as well as, manages the post-checkpoint state by providing the updated partitions,
* iterator and preferred locations of the checkpointed RDD.
*/
-private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
- extends Logging with Serializable {
+private[spark] abstract class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
+ extends Serializable {
import CheckpointState._
// The checkpoint state of the associated RDD.
- private var cpState = Initialized
-
- // The file to which the associated RDD has been checkpointed to
- private var cpFile: Option[String] = None
+ protected var cpState = Initialized
- // The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD.
- // This is defined if and only if `cpState` is `Checkpointed`.
+ // The RDD that contains our checkpointed data
private var cpRDD: Option[CheckpointRDD[T]] = None
// TODO: are we sure we need to use a global lock in the following methods?
- // Is the RDD already checkpointed
+ /**
+ * Return whether the checkpoint data for this RDD is already persisted.
+ */
def isCheckpointed: Boolean = RDDCheckpointData.synchronized {
cpState == Checkpointed
}
- // Get the file to which this RDD was checkpointed to as an Option
- def getCheckpointFile: Option[String] = RDDCheckpointData.synchronized {
- cpFile
- }
-
/**
- * Materialize this RDD and write its content to a reliable DFS.
+ * Materialize this RDD and persist its content.
* This is called immediately after the first action invoked on this RDD has completed.
*/
- def doCheckpoint(): Unit = {
-
+ final def checkpoint(): Unit = {
// Guard against multiple threads checkpointing the same RDD by
// atomically flipping the state of this RDDCheckpointData
RDDCheckpointData.synchronized {
@@ -82,64 +71,41 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
}
}
- // Create the output path for the checkpoint
- val path = RDDCheckpointData.rddCheckpointDataPath(rdd.context, rdd.id).get
- val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
- if (!fs.mkdirs(path)) {
- throw new SparkException(s"Failed to create checkpoint path $path")
- }
-
- // Save to file, and reload it as an RDD
- val broadcastedConf = rdd.context.broadcast(
- new SerializableConfiguration(rdd.context.hadoopConfiguration))
- val newRDD = new CheckpointRDD[T](rdd.context, path.toString)
- if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) {
- rdd.context.cleaner.foreach { cleaner =>
- cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id)
- }
- }
-
- // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
- rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _)
- if (newRDD.partitions.length != rdd.partitions.length) {
- throw new SparkException(
- "Checkpoint RDD " + newRDD + "(" + newRDD.partitions.length + ") has different " +
- "number of partitions than original RDD " + rdd + "(" + rdd.partitions.length + ")")
- }
+ val newRDD = doCheckpoint()
- // Change the dependencies and partitions of the RDD
+ // Update our state and truncate the RDD lineage
RDDCheckpointData.synchronized {
- cpFile = Some(path.toString)
cpRDD = Some(newRDD)
- rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions
cpState = Checkpointed
+ rdd.markCheckpointed()
}
- logInfo(s"Done checkpointing RDD ${rdd.id} to $path, new parent is RDD ${newRDD.id}")
- }
-
- def getPartitions: Array[Partition] = RDDCheckpointData.synchronized {
- cpRDD.get.partitions
}
- def checkpointRDD: Option[CheckpointRDD[T]] = RDDCheckpointData.synchronized {
- cpRDD
- }
-}
+ /**
+ * Materialize this RDD and persist its content.
+ *
+ * Subclasses should override this method to define custom checkpointing behavior.
+ * @return the checkpoint RDD created in the process.
+ */
+ protected def doCheckpoint(): CheckpointRDD[T]
-private[spark] object RDDCheckpointData {
+ /**
+ * Return the RDD that contains our checkpointed data.
+ * This is only defined if the checkpoint state is `Checkpointed`.
+ */
+ def checkpointRDD: Option[CheckpointRDD[T]] = RDDCheckpointData.synchronized { cpRDD }
- /** Return the path of the directory to which this RDD's checkpoint data is written. */
- def rddCheckpointDataPath(sc: SparkContext, rddId: Int): Option[Path] = {
- sc.checkpointDir.map { dir => new Path(dir, s"rdd-$rddId") }
+ /**
+ * Return the partitions of the resulting checkpoint RDD.
+ * For tests only.
+ */
+ def getPartitions: Array[Partition] = RDDCheckpointData.synchronized {
+ cpRDD.map(_.partitions).getOrElse { Array.empty }
}
- /** Clean up the files associated with the checkpoint data for this RDD. */
- def clearRDDCheckpointData(sc: SparkContext, rddId: Int): Unit = {
- rddCheckpointDataPath(sc, rddId).foreach { path =>
- val fs = path.getFileSystem(sc.hadoopConfiguration)
- if (fs.exists(path)) {
- fs.delete(path, true)
- }
- }
- }
}
+
+/**
+ * Global lock for synchronizing checkpoint operations.
+ */
+private[spark] object RDDCheckpointData
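
Since checkpoint() is now final in the base class, a subclass only supplies doCheckpoint(). A hypothetical skeleton (names illustrative, not part of this patch) to make the contract concrete:

    private[spark] class ExampleRDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
      extends RDDCheckpointData[T](rdd) {

      // Materialize `rdd` somewhere recoverable, then return a CheckpointRDD that can
      // serve those partitions back; the base class records the result, flips the state
      // to Checkpointed, and truncates the RDD's lineage.
      protected override def doCheckpoint(): CheckpointRDD[T] = {
        new LocalCheckpointRDD[T](rdd)  // placeholder body for illustration only
      }
    }
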
diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
new file mode 100644
index 0000000000..35d8b0bfd1
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import java.io.IOException
+
+import scala.reflect.ClassTag
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark._
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.util.{SerializableConfiguration, Utils}
+
+/**
+ * An RDD that reads from checkpoint files previously written to reliable storage.
+ */
+private[spark] class ReliableCheckpointRDD[T: ClassTag](
+ @transient sc: SparkContext,
+ val checkpointPath: String)
+ extends CheckpointRDD[T](sc) {
+
+ @transient private val hadoopConf = sc.hadoopConfiguration
+ @transient private val cpath = new Path(checkpointPath)
+ @transient private val fs = cpath.getFileSystem(hadoopConf)
+ private val broadcastedConf = sc.broadcast(new SerializableConfiguration(hadoopConf))
+
+ // Fail fast if checkpoint directory does not exist
+ require(fs.exists(cpath), s"Checkpoint directory does not exist: $checkpointPath")
+
+ /**
+ * Return the path of the checkpoint directory this RDD reads data from.
+ */
+ override def getCheckpointFile: Option[String] = Some(checkpointPath)
+
+ /**
+ * Return partitions described by the files in the checkpoint directory.
+ *
+ * Since the original RDD may belong to a prior application, there is no way to know a
+ * priori the number of partitions to expect. This method assumes that the original set of
+ * checkpoint files are fully preserved in a reliable storage across application lifespans.
+ */
+ protected override def getPartitions: Array[Partition] = {
+ // listStatus can throw exception if path does not exist.
+ val inputFiles = fs.listStatus(cpath)
+ .map(_.getPath)
+ .filter(_.getName.startsWith("part-"))
+ .sortBy(_.toString)
+ // Fail fast if input files are invalid
+ inputFiles.zipWithIndex.foreach { case (path, i) =>
+ if (!path.toString.endsWith(ReliableCheckpointRDD.checkpointFileName(i))) {
+ throw new SparkException(s"Invalid checkpoint file: $path")
+ }
+ }
+ Array.tabulate(inputFiles.length)(i => new CheckpointRDDPartition(i))
+ }
+
+ /**
+ * Return the locations of the checkpoint file associated with the given partition.
+ */
+ protected override def getPreferredLocations(split: Partition): Seq[String] = {
+ val status = fs.getFileStatus(
+ new Path(checkpointPath, ReliableCheckpointRDD.checkpointFileName(split.index)))
+ val locations = fs.getFileBlockLocations(status, 0, status.getLen)
+ locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
+ }
+
+ /**
+ * Read the content of the checkpoint file associated with the given partition.
+ */
+ override def compute(split: Partition, context: TaskContext): Iterator[T] = {
+ val file = new Path(checkpointPath, ReliableCheckpointRDD.checkpointFileName(split.index))
+ ReliableCheckpointRDD.readCheckpointFile(file, broadcastedConf, context)
+ }
+
+}
+
+private[spark] object ReliableCheckpointRDD extends Logging {
+
+ /**
+ * Return the checkpoint file name for the given partition.
+ */
+ private def checkpointFileName(partitionIndex: Int): String = {
+ "part-%05d".format(partitionIndex)
+ }
+
+ /**
+ * Write this partition's values to a checkpoint file.
+ */
+ def writeCheckpointFile[T: ClassTag](
+ path: String,
+ broadcastedConf: Broadcast[SerializableConfiguration],
+ blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
+ val env = SparkEnv.get
+ val outputDir = new Path(path)
+ val fs = outputDir.getFileSystem(broadcastedConf.value.value)
+
+ val finalOutputName = ReliableCheckpointRDD.checkpointFileName(ctx.partitionId())
+ val finalOutputPath = new Path(outputDir, finalOutputName)
+ val tempOutputPath =
+ new Path(outputDir, s".$finalOutputName-attempt-${ctx.attemptNumber()}")
+
+ if (fs.exists(tempOutputPath)) {
+ throw new IOException(s"Checkpoint failed: temporary path $tempOutputPath already exists")
+ }
+ val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
+
+ val fileOutputStream = if (blockSize < 0) {
+ fs.create(tempOutputPath, false, bufferSize)
+ } else {
+ // This is mainly for testing purpose
+ fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize)
+ }
+ val serializer = env.serializer.newInstance()
+ val serializeStream = serializer.serializeStream(fileOutputStream)
+ Utils.tryWithSafeFinally {
+ serializeStream.writeAll(iterator)
+ } {
+ serializeStream.close()
+ }
+
+ if (!fs.rename(tempOutputPath, finalOutputPath)) {
+ if (!fs.exists(finalOutputPath)) {
+ logInfo(s"Deleting tempOutputPath $tempOutputPath")
+ fs.delete(tempOutputPath, false)
+ throw new IOException("Checkpoint failed: failed to save output of task: " +
+ s"${ctx.attemptNumber()} and final output path does not exist: $finalOutputPath")
+ } else {
+ // Some other copy of this task must've finished before us and renamed it
+ logInfo(s"Final output path $finalOutputPath already exists; not overwriting it")
+ fs.delete(tempOutputPath, false)
+ }
+ }
+ }
+
+ /**
+ * Read the content of the specified checkpoint file.
+ */
+ def readCheckpointFile[T](
+ path: Path,
+ broadcastedConf: Broadcast[SerializableConfiguration],
+ context: TaskContext): Iterator[T] = {
+ val env = SparkEnv.get
+ val fs = path.getFileSystem(broadcastedConf.value.value)
+ val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
+ val fileInputStream = fs.open(path, bufferSize)
+ val serializer = env.serializer.newInstance()
+ val deserializeStream = serializer.deserializeStream(fileInputStream)
+
+ // Register an on-task-completion callback to close the input stream.
+ context.addTaskCompletionListener(context => deserializeStream.close())
+
+ deserializeStream.asIterator.asInstanceOf[Iterator[T]]
+ }
+
+}
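
The resulting on-disk layout, combining checkpointFileName here with the rdd-<id> directory created by ReliableRDDCheckpointData below:

    <checkpointDir>/rdd-<rddId>/part-00000
    <checkpointDir>/rdd-<rddId>/part-00001
    ...                                       // one "part-%05d" file per partition
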
diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala
new file mode 100644
index 0000000000..1df8eef5ff
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.reflect.ClassTag
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark._
+import org.apache.spark.util.SerializableConfiguration
+
+/**
+ * An implementation of checkpointing that writes the RDD data to reliable storage.
+ * This allows drivers to be restarted on failure with previously computed state.
+ */
+private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
+ extends RDDCheckpointData[T](rdd) with Logging {
+
+ // The directory to which the associated RDD has been checkpointed to
+ // This is assumed to be a non-local path that points to some reliable storage
+ private val cpDir: String =
+ ReliableRDDCheckpointData.checkpointPath(rdd.context, rdd.id)
+ .map(_.toString)
+ .getOrElse { throw new SparkException("Checkpoint dir must be specified.") }
+
+ /**
+ * Return the directory to which this RDD was checkpointed.
+ * If the RDD is not checkpointed yet, return None.
+ */
+ def getCheckpointDir: Option[String] = RDDCheckpointData.synchronized {
+ if (isCheckpointed) {
+ Some(cpDir.toString)
+ } else {
+ None
+ }
+ }
+
+ /**
+ * Materialize this RDD and write its content to a reliable DFS.
+ * This is called immediately after the first action invoked on this RDD has completed.
+ */
+ protected override def doCheckpoint(): CheckpointRDD[T] = {
+
+ // Create the output path for the checkpoint
+ val path = new Path(cpDir)
+ val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
+ if (!fs.mkdirs(path)) {
+ throw new SparkException(s"Failed to create checkpoint path $cpDir")
+ }
+
+ // Save to file, and reload it as an RDD
+ val broadcastedConf = rdd.context.broadcast(
+ new SerializableConfiguration(rdd.context.hadoopConfiguration))
+ // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
+ rdd.context.runJob(rdd, ReliableCheckpointRDD.writeCheckpointFile[T](cpDir, broadcastedConf) _)
+ val newRDD = new ReliableCheckpointRDD[T](rdd.context, cpDir)
+ if (newRDD.partitions.length != rdd.partitions.length) {
+ throw new SparkException(
+ s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " +
+ s"number of partitions from original RDD $rdd(${rdd.partitions.length})")
+ }
+
+ // Optionally clean our checkpoint files if the reference is out of scope
+ if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) {
+ rdd.context.cleaner.foreach { cleaner =>
+ cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id)
+ }
+ }
+
+ logInfo(s"Done checkpointing RDD ${rdd.id} to $cpDir, new parent is RDD ${newRDD.id}")
+
+ newRDD
+ }
+
+}
+
+private[spark] object ReliableRDDCheckpointData {
+
+ /** Return the path of the directory to which this RDD's checkpoint data is written. */
+ def checkpointPath(sc: SparkContext, rddId: Int): Option[Path] = {
+ sc.checkpointDir.map { dir => new Path(dir, s"rdd-$rddId") }
+ }
+
+ /** Clean up the files associated with the checkpoint data for this RDD. */
+ def cleanCheckpoint(sc: SparkContext, rddId: Int): Unit = {
+ checkpointPath(sc, rddId).foreach { path =>
+ val fs = path.getFileSystem(sc.hadoopConfiguration)
+ if (fs.exists(path)) {
+ fs.delete(path, true)
+ }
+ }
+ }
+}
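
A minimal reliable-checkpoint round trip that exercises doCheckpoint() above; note that sc.checkpointFile is protected[spark], so the read-back step is available only to Spark-internal code such as the test suite below:

    sc.setCheckpointDir("hdfs:///tmp/checkpoints")  // any reliable, shared path
    val rdd = sc.parallelize(1 to 100, 4)
    rdd.checkpoint()
    rdd.count()  // runs the extra job in doCheckpoint() and writes the part-0000* files
    assert(rdd.getCheckpointFile.isDefined)
    // Spark-internal read-back, as done in CheckpointSuite:
    // val restored = sc.checkpointFile[Int](rdd.getCheckpointFile.get)
    // assert(restored.collect().sorted.sameElements(1 to 100))
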
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index cc50e6d79a..d343bb95cb 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -25,11 +25,15 @@ import org.apache.spark.rdd._
import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}
import org.apache.spark.util.Utils
+/**
+ * Test suite for end-to-end checkpointing functionality.
+ * This tests both reliable checkpoints and local checkpoints.
+ */
class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging {
- var checkpointDir: File = _
- val partitioner = new HashPartitioner(2)
+ private var checkpointDir: File = _
+ private val partitioner = new HashPartitioner(2)
- override def beforeEach() {
+ override def beforeEach(): Unit = {
super.beforeEach()
checkpointDir = File.createTempFile("temp", "", Utils.createTempDir())
checkpointDir.delete()
@@ -37,40 +41,43 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
sc.setCheckpointDir(checkpointDir.toString)
}
- override def afterEach() {
+ override def afterEach(): Unit = {
super.afterEach()
Utils.deleteRecursively(checkpointDir)
}
- test("basic checkpointing") {
+ runTest("basic checkpointing") { reliableCheckpoint: Boolean =>
val parCollection = sc.makeRDD(1 to 4)
val flatMappedRDD = parCollection.flatMap(x => 1 to x)
- flatMappedRDD.checkpoint()
+ checkpoint(flatMappedRDD, reliableCheckpoint)
assert(flatMappedRDD.dependencies.head.rdd === parCollection)
val result = flatMappedRDD.collect()
assert(flatMappedRDD.dependencies.head.rdd != parCollection)
assert(flatMappedRDD.collect() === result)
}
- test("RDDs with one-to-one dependencies") {
- testRDD(_.map(x => x.toString))
- testRDD(_.flatMap(x => 1 to x))
- testRDD(_.filter(_ % 2 == 0))
- testRDD(_.sample(false, 0.5, 0))
- testRDD(_.glom())
- testRDD(_.mapPartitions(_.map(_.toString)))
- testRDD(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
- testRDD(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
- testRDD(_.pipe(Seq("cat")))
+ runTest("RDDs with one-to-one dependencies") { reliableCheckpoint: Boolean =>
+ testRDD(_.map(x => x.toString), reliableCheckpoint)
+ testRDD(_.flatMap(x => 1 to x), reliableCheckpoint)
+ testRDD(_.filter(_ % 2 == 0), reliableCheckpoint)
+ testRDD(_.sample(false, 0.5, 0), reliableCheckpoint)
+ testRDD(_.glom(), reliableCheckpoint)
+ testRDD(_.mapPartitions(_.map(_.toString)), reliableCheckpoint)
+ testRDD(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString), reliableCheckpoint)
+ testRDD(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x),
+ reliableCheckpoint)
+ testRDD(_.pipe(Seq("cat")), reliableCheckpoint)
}
- test("ParallelCollection") {
+ runTest("ParallelCollectionRDD") { reliableCheckpoint: Boolean =>
val parCollection = sc.makeRDD(1 to 4, 2)
val numPartitions = parCollection.partitions.size
- parCollection.checkpoint()
+ checkpoint(parCollection, reliableCheckpoint)
assert(parCollection.dependencies === Nil)
val result = parCollection.collect()
- assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result)
+ if (reliableCheckpoint) {
+ assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result)
+ }
assert(parCollection.dependencies != Nil)
assert(parCollection.partitions.length === numPartitions)
assert(parCollection.partitions.toList ===
@@ -78,44 +85,46 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
assert(parCollection.collect() === result)
}
- test("BlockRDD") {
+ runTest("BlockRDD") { reliableCheckpoint: Boolean =>
val blockId = TestBlockId("id")
val blockManager = SparkEnv.get.blockManager
blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY)
val blockRDD = new BlockRDD[String](sc, Array(blockId))
val numPartitions = blockRDD.partitions.size
- blockRDD.checkpoint()
+ checkpoint(blockRDD, reliableCheckpoint)
val result = blockRDD.collect()
- assert(sc.checkpointFile[String](blockRDD.getCheckpointFile.get).collect() === result)
+ if (reliableCheckpoint) {
+ assert(sc.checkpointFile[String](blockRDD.getCheckpointFile.get).collect() === result)
+ }
assert(blockRDD.dependencies != Nil)
assert(blockRDD.partitions.length === numPartitions)
assert(blockRDD.partitions.toList === blockRDD.checkpointData.get.getPartitions.toList)
assert(blockRDD.collect() === result)
}
- test("ShuffledRDD") {
+ runTest("ShuffleRDD") { reliableCheckpoint: Boolean =>
testRDD(rdd => {
// Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD
new ShuffledRDD[Int, Int, Int](rdd.map(x => (x % 2, 1)), partitioner)
- })
+ }, reliableCheckpoint)
}
- test("UnionRDD") {
+ runTest("UnionRDD") { reliableCheckpoint: Boolean =>
def otherRDD: RDD[Int] = sc.makeRDD(1 to 10, 1)
- testRDD(_.union(otherRDD))
- testRDDPartitions(_.union(otherRDD))
+ testRDD(_.union(otherRDD), reliableCheckpoint)
+ testRDDPartitions(_.union(otherRDD), reliableCheckpoint)
}
- test("CartesianRDD") {
+ runTest("CartesianRDD") { reliableCheckpoint: Boolean =>
def otherRDD: RDD[Int] = sc.makeRDD(1 to 10, 1)
- testRDD(new CartesianRDD(sc, _, otherRDD))
- testRDDPartitions(new CartesianRDD(sc, _, otherRDD))
+ testRDD(new CartesianRDD(sc, _, otherRDD), reliableCheckpoint)
+ testRDDPartitions(new CartesianRDD(sc, _, otherRDD), reliableCheckpoint)
// Test that the CartesianRDD updates parent partitions (CartesianRDD.s1/s2) after
// the parent RDD has been checkpointed and parent partitions have been changed.
// Note that this test is very specific to the current implementation of CartesianRDD.
val ones = sc.makeRDD(1 to 100, 10).map(x => x)
- ones.checkpoint() // checkpoint that MappedRDD
+ checkpoint(ones, reliableCheckpoint) // checkpoint that MappedRDD
val cartesian = new CartesianRDD(sc, ones, ones)
val splitBeforeCheckpoint =
serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition])
@@ -129,16 +138,16 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
)
}
- test("CoalescedRDD") {
- testRDD(_.coalesce(2))
- testRDDPartitions(_.coalesce(2))
+ runTest("CoalescedRDD") { reliableCheckpoint: Boolean =>
+ testRDD(_.coalesce(2), reliableCheckpoint)
+ testRDDPartitions(_.coalesce(2), reliableCheckpoint)
// Test that the CoalescedRDDPartition updates parent partitions (CoalescedRDDPartition.parents)
// after the parent RDD has been checkpointed and parent partitions have been changed.
// Note that this test is very specific to the current implementation of
// CoalescedRDDPartitions.
val ones = sc.makeRDD(1 to 100, 10).map(x => x)
- ones.checkpoint() // checkpoint that MappedRDD
+ checkpoint(ones, reliableCheckpoint) // checkpoint that MappedRDD
val coalesced = new CoalescedRDD(ones, 2)
val splitBeforeCheckpoint =
serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition])
@@ -151,7 +160,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
)
}
- test("CoGroupedRDD") {
+ runTest("CoGroupedRDD") { reliableCheckpoint: Boolean =>
val longLineageRDD1 = generateFatPairRDD()
// Collect the RDD as sequences instead of arrays to enable equality tests in testRDD
@@ -160,26 +169,26 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
testRDD(rdd => {
CheckpointSuite.cogroup(longLineageRDD1, rdd.map(x => (x % 2, 1)), partitioner)
- }, seqCollectFunc)
+ }, reliableCheckpoint, seqCollectFunc)
val longLineageRDD2 = generateFatPairRDD()
testRDDPartitions(rdd => {
CheckpointSuite.cogroup(
longLineageRDD2, sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)), partitioner)
- }, seqCollectFunc)
+ }, reliableCheckpoint, seqCollectFunc)
}
- test("ZippedPartitionsRDD") {
- testRDD(rdd => rdd.zip(rdd.map(x => x)))
- testRDDPartitions(rdd => rdd.zip(rdd.map(x => x)))
+ runTest("ZippedPartitionsRDD") { reliableCheckpoint: Boolean =>
+ testRDD(rdd => rdd.zip(rdd.map(x => x)), reliableCheckpoint)
+ testRDDPartitions(rdd => rdd.zip(rdd.map(x => x)), reliableCheckpoint)
// Test that ZippedPartitionsRDD updates parent partitions after parent RDDs have
// been checkpointed and parent partitions have been changed.
// Note that this test is very specific to the implementation of ZippedPartitionsRDD.
val rdd = generateFatRDD()
val zippedRDD = rdd.zip(rdd.map(x => x)).asInstanceOf[ZippedPartitionsRDD2[_, _, _]]
- zippedRDD.rdd1.checkpoint()
- zippedRDD.rdd2.checkpoint()
+ checkpoint(zippedRDD.rdd1, reliableCheckpoint)
+ checkpoint(zippedRDD.rdd2, reliableCheckpoint)
val partitionBeforeCheckpoint =
serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartitionsPartition])
zippedRDD.count()
@@ -194,27 +203,27 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
)
}
- test("PartitionerAwareUnionRDD") {
+ runTest("PartitionerAwareUnionRDD") { reliableCheckpoint: Boolean =>
testRDD(rdd => {
new PartitionerAwareUnionRDD[(Int, Int)](sc, Array(
generateFatPairRDD(),
rdd.map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _)
))
- })
+ }, reliableCheckpoint)
testRDDPartitions(rdd => {
new PartitionerAwareUnionRDD[(Int, Int)](sc, Array(
generateFatPairRDD(),
rdd.map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _)
))
- })
+ }, reliableCheckpoint)
// Test that the PartitionerAwareUnionRDD updates parent partitions
// (PartitionerAwareUnionRDD.parents) after the parent RDD has been checkpointed and parent
// partitions have been changed. Note that this test is very specific to the current
// implementation of PartitionerAwareUnionRDD.
val pairRDD = generateFatPairRDD()
- pairRDD.checkpoint()
+ checkpoint(pairRDD, reliableCheckpoint)
val unionRDD = new PartitionerAwareUnionRDD(sc, Array(pairRDD))
val partitionBeforeCheckpoint = serializeDeserialize(
unionRDD.partitions.head.asInstanceOf[PartitionerAwareUnionRDDPartition])
@@ -228,17 +237,34 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
)
}
- test("CheckpointRDD with zero partitions") {
+ runTest("CheckpointRDD with zero partitions") { reliableCheckpoint: Boolean =>
val rdd = new BlockRDD[Int](sc, Array[BlockId]())
assert(rdd.partitions.size === 0)
assert(rdd.isCheckpointed === false)
- rdd.checkpoint()
+ checkpoint(rdd, reliableCheckpoint)
assert(rdd.count() === 0)
assert(rdd.isCheckpointed === true)
assert(rdd.partitions.size === 0)
}
- def defaultCollectFunc[T](rdd: RDD[T]): Any = rdd.collect()
+ // Utility test methods
+
+ /** Checkpoint the RDD either locally or reliably. */
+ private def checkpoint(rdd: RDD[_], reliableCheckpoint: Boolean): Unit = {
+ if (reliableCheckpoint) {
+ rdd.checkpoint()
+ } else {
+ rdd.localCheckpoint()
+ }
+ }
+
+ /** Run a test twice, once for local checkpointing and once for reliable checkpointing. */
+ private def runTest(name: String)(body: Boolean => Unit): Unit = {
+ test(name + " [reliable checkpoint]")(body(true))
+ test(name + " [local checkpoint]")(body(false))
+ }
+
+ private def defaultCollectFunc[T](rdd: RDD[T]): Any = rdd.collect()
/**
* Test checkpointing of the RDD generated by the given operation. It tests whether the
@@ -246,11 +272,14 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
* on all RDDs that have a parent RDD (i.e., do not call on ParallelCollection, BlockRDD, etc.).
*
* @param op an operation to run on the RDD
+ * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints
* @param collectFunc a function for collecting the values in the RDD, in case there are
* non-comparable types like arrays that we want to convert to something that supports ==
*/
- def testRDD[U: ClassTag](op: (RDD[Int]) => RDD[U],
- collectFunc: RDD[U] => Any = defaultCollectFunc[U] _) {
+ private def testRDD[U: ClassTag](
+ op: (RDD[Int]) => RDD[U],
+ reliableCheckpoint: Boolean,
+ collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = {
// Generate the final RDD using given RDD operation
val baseRDD = generateFatRDD()
val operatedRDD = op(baseRDD)
@@ -267,14 +296,16 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
// Find serialized sizes before and after the checkpoint
logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
- operatedRDD.checkpoint()
+ checkpoint(operatedRDD, reliableCheckpoint)
val result = collectFunc(operatedRDD)
operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables
val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
// Test whether the checkpoint file has been created
- assert(collectFunc(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result)
+ if (reliableCheckpoint) {
+ assert(collectFunc(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result)
+ }
// Test whether dependencies have been changed from its earlier parent RDD
assert(operatedRDD.dependencies.head.rdd != parentRDD)
@@ -310,11 +341,14 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
* partitions (i.e., do not call it on simple RDD like MappedRDD).
*
* @param op an operation to run on the RDD
+ * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints
* @param collectFunc a function for collecting the values in the RDD, in case there are
* non-comparable types like arrays that we want to convert to something that supports ==
*/
- def testRDDPartitions[U: ClassTag](op: (RDD[Int]) => RDD[U],
- collectFunc: RDD[U] => Any = defaultCollectFunc[U] _) {
+ private def testRDDPartitions[U: ClassTag](
+ op: (RDD[Int]) => RDD[U],
+ reliableCheckpoint: Boolean,
+ collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = {
// Generate the final RDD using given RDD operation
val baseRDD = generateFatRDD()
val operatedRDD = op(baseRDD)
@@ -328,7 +362,10 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
// Find serialized sizes before and after the checkpoint
logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
- parentRDDs.foreach(_.checkpoint()) // checkpoint the parent RDD, not the generated one
+ // checkpoint the parent RDD, not the generated one
+ parentRDDs.foreach { rdd =>
+ checkpoint(rdd, reliableCheckpoint)
+ }
val result = collectFunc(operatedRDD) // force checkpointing
operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables
val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
@@ -350,7 +387,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
/**
* Generate an RDD such that both the RDD and its partitions have large size.
*/
- def generateFatRDD(): RDD[Int] = {
+ private def generateFatRDD(): RDD[Int] = {
new FatRDD(sc.makeRDD(1 to 100, 4)).map(x => x)
}
@@ -358,7 +395,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
 * Generate a pair RDD (with a partitioner) such that both the RDD and its partitions
* have large size.
*/
- def generateFatPairRDD(): RDD[(Int, Int)] = {
+ private def generateFatPairRDD(): RDD[(Int, Int)] = {
new FatPairRDD(sc.makeRDD(1 to 100, 4), partitioner).mapValues(x => x)
}
@@ -366,7 +403,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
* Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks
* upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint.
*/
- def getSerializedSizes(rdd: RDD[_]): (Int, Int) = {
+ private def getSerializedSizes(rdd: RDD[_]): (Int, Int) = {
val rddSize = Utils.serialize(rdd).size
val rddCpDataSize = Utils.serialize(rdd.checkpointData).size
val rddPartitionSize = Utils.serialize(rdd.partitions).size
@@ -394,7 +431,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
* contents after deserialization (e.g., the contents of an RDD split after
* it is sent to a slave along with a task)
*/
- def serializeDeserialize[T](obj: T): T = {
+ private def serializeDeserialize[T](obj: T): T = {
val bytes = Utils.serialize(obj)
Utils.deserialize[T](bytes)
}
@@ -402,10 +439,11 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
/**
 * Recursively force the initialization of all the members of an RDD and its parents.
*/
- def initializeRdd(rdd: RDD[_]) {
+ private def initializeRdd(rdd: RDD[_]): Unit = {
 rdd.partitions // forces evaluation of the partitions
- rdd.dependencies.map(_.rdd).foreach(initializeRdd(_))
+ rdd.dependencies.map(_.rdd).foreach(initializeRdd)
}
+
}
/** RDD partition that has large serialized size. */
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
index 26858ef277..0c14bef7be 100644
--- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -24,12 +24,11 @@ import scala.language.existentials
import scala.util.Random
import org.scalatest.BeforeAndAfter
-import org.scalatest.concurrent.{PatienceConfiguration, Eventually}
+import org.scalatest.concurrent.PatienceConfiguration
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._
-import org.apache.spark.SparkContext._
-import org.apache.spark.rdd.{RDDCheckpointData, RDD}
+import org.apache.spark.rdd.{ReliableRDDCheckpointData, RDD}
import org.apache.spark.storage._
import org.apache.spark.shuffle.hash.HashShuffleManager
import org.apache.spark.shuffle.sort.SortShuffleManager
@@ -52,6 +51,7 @@ abstract class ContextCleanerSuiteBase(val shuffleManager: Class[_] = classOf[Ha
.setAppName("ContextCleanerSuite")
.set("spark.cleaner.referenceTracking.blocking", "true")
.set("spark.cleaner.referenceTracking.blocking.shuffle", "true")
+ .set("spark.cleaner.referenceTracking.cleanCheckpoints", "true")
.set("spark.shuffle.manager", shuffleManager.getName)
before {
@@ -209,11 +209,11 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase {
postGCTester.assertCleanup()
}
- test("automatically cleanup checkpoint") {
+ test("automatically cleanup normal checkpoint") {
val checkpointDir = java.io.File.createTempFile("temp", "")
checkpointDir.deleteOnExit()
checkpointDir.delete()
- var rdd = newPairRDD
+ var rdd = newPairRDD()
sc.setCheckpointDir(checkpointDir.toString)
rdd.checkpoint()
rdd.cache()
@@ -221,23 +221,26 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase {
var rddId = rdd.id
// Confirm the checkpoint directory exists
- assert(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).isDefined)
- val path = RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get
+ assert(ReliableRDDCheckpointData.checkpointPath(sc, rddId).isDefined)
+ val path = ReliableRDDCheckpointData.checkpointPath(sc, rddId).get
val fs = path.getFileSystem(sc.hadoopConfiguration)
assert(fs.exists(path))
 // The checkpoint should be cleaned up on GC here, since cleanCheckpoints is enabled in this suite
- var postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil, Nil)
+ var postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil, Seq(rddId))
rdd = null // Make RDD out of scope, ok if collected earlier
runGC()
postGCTester.assertCleanup()
- assert(fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get))
+ assert(!fs.exists(ReliableRDDCheckpointData.checkpointPath(sc, rddId).get))
+ // Verify that checkpoints are NOT cleaned up if the config is not enabled
sc.stop()
- val conf = new SparkConf().setMaster("local[2]").setAppName("cleanupCheckpoint").
- set("spark.cleaner.referenceTracking.cleanCheckpoints", "true")
+ val conf = new SparkConf()
+ .setMaster("local[2]")
+ .setAppName("cleanupCheckpoint")
+ .set("spark.cleaner.referenceTracking.cleanCheckpoints", "false")
sc = new SparkContext(conf)
- rdd = newPairRDD
+ rdd = newPairRDD()
sc.setCheckpointDir(checkpointDir.toString)
rdd.checkpoint()
rdd.cache()
@@ -245,17 +248,40 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase {
rddId = rdd.id
// Confirm the checkpoint directory exists
- assert(fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get))
+ assert(fs.exists(ReliableRDDCheckpointData.checkpointPath(sc, rddId).get))
// Reference rdd to defeat any early collection by the JVM
rdd.count()
// Test that GC causes checkpoint data cleanup after dereferencing the RDD
- postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil, Seq(rddId))
+ postGCTester = new CleanerTester(sc, Seq(rddId))
rdd = null // Make RDD out of scope
runGC()
postGCTester.assertCleanup()
- assert(!fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get))
+ assert(fs.exists(ReliableRDDCheckpointData.checkpointPath(sc, rddId).get))
+ }
+
+ test("automatically clean up local checkpoint") {
+ // Note that this test is similar to the RDD cleanup
+ // test because the same underlying mechanism is used!
+ var rdd = newPairRDD().localCheckpoint()
+ assert(rdd.checkpointData.isDefined)
+ assert(rdd.checkpointData.get.checkpointRDD.isEmpty)
+ rdd.count()
+ assert(rdd.checkpointData.get.checkpointRDD.isDefined)
+
+ // Test that GC does not cause checkpoint cleanup due to a strong reference
+ val preGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id))
+ runGC()
+ intercept[Exception] {
+ preGCTester.assertCleanup()(timeout(1000 millis))
+ }
+
+ // Test that RDD going out of scope does cause the checkpoint blocks to be cleaned up
+ val postGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id))
+ rdd = null
+ runGC()
+ postGCTester.assertCleanup()
}
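
This cleanup path matters most for long-running iterative jobs. A rough, hypothetical usage pattern (not part of this suite) that relies on it:

  // Each iteration locally checkpoints its result; once an older iteration's
  // RDD goes out of scope and is garbage collected, the ContextCleaner can
  // drop that RDD's checkpoint blocks, as asserted above.
  var current = sc.parallelize(1 to 1000, 4)
  for (_ <- 1 to 10) {
    current = current.map(_ + 1).localCheckpoint()
    current.count() // materializes the checkpoint blocks for this iteration
  }
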
test("automatically cleanup RDD + shuffle + broadcast") {
@@ -408,7 +434,10 @@ class SortShuffleContextCleanerSuite extends ContextCleanerSuiteBase(classOf[Sor
}
-/** Class to test whether RDDs, shuffles, etc. have been successfully cleaned. */
+/**
+ * Class to test whether RDDs, shuffles, etc. have been successfully cleaned.
+ * The checkpoint here refers only to normal (reliable) checkpoints, not local checkpoints.
+ */
class CleanerTester(
sc: SparkContext,
rddIds: Seq[Int] = Seq.empty,
diff --git a/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala
new file mode 100644
index 0000000000..5103eb74b2
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala
@@ -0,0 +1,330 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark.{SparkException, SparkContext, LocalSparkContext, SparkFunSuite}
+
+import org.mockito.Mockito.spy
+import org.apache.spark.storage.{RDDBlockId, StorageLevel}
+
+/**
+ * Fine-grained tests for local checkpointing.
+ * For end-to-end tests, see CheckpointSuite.
+ */
+class LocalCheckpointSuite extends SparkFunSuite with LocalSparkContext {
+
+ override def beforeEach(): Unit = {
+ sc = new SparkContext("local[2]", "test")
+ }
+
+ test("transform storage level") {
+ val transform = LocalRDDCheckpointData.transformStorageLevel _
+ assert(transform(StorageLevel.NONE) === StorageLevel.DISK_ONLY)
+ assert(transform(StorageLevel.MEMORY_ONLY) === StorageLevel.MEMORY_AND_DISK)
+ assert(transform(StorageLevel.MEMORY_ONLY_SER) === StorageLevel.MEMORY_AND_DISK_SER)
+ assert(transform(StorageLevel.MEMORY_ONLY_2) === StorageLevel.MEMORY_AND_DISK_2)
+ assert(transform(StorageLevel.MEMORY_ONLY_SER_2) === StorageLevel.MEMORY_AND_DISK_SER_2)
+ assert(transform(StorageLevel.DISK_ONLY) === StorageLevel.DISK_ONLY)
+ assert(transform(StorageLevel.DISK_ONLY_2) === StorageLevel.DISK_ONLY_2)
+ assert(transform(StorageLevel.MEMORY_AND_DISK) === StorageLevel.MEMORY_AND_DISK)
+ assert(transform(StorageLevel.MEMORY_AND_DISK_SER) === StorageLevel.MEMORY_AND_DISK_SER)
+ assert(transform(StorageLevel.MEMORY_AND_DISK_2) === StorageLevel.MEMORY_AND_DISK_2)
+ assert(transform(StorageLevel.MEMORY_AND_DISK_SER_2) === StorageLevel.MEMORY_AND_DISK_SER_2)
+ // Off-heap is not supported and Spark should fail fast
+ intercept[SparkException] {
+ transform(StorageLevel.OFF_HEAP)
+ }
+ }
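
The assertions above pin down the intended mapping: every level gains a disk component while memory use, serialization, and replication are preserved, and off-heap caching is rejected. A minimal sketch consistent with that table (the real LocalRDDCheckpointData.transformStorageLevel may differ in its details):

  import org.apache.spark.SparkException
  import org.apache.spark.storage.StorageLevel

  // Sketch only: force a disk component, keep everything else, reject off-heap.
  def transformStorageLevelSketch(level: StorageLevel): StorageLevel = {
    if (level == StorageLevel.OFF_HEAP) {
      throw new SparkException("Local checkpointing is not compatible with off-heap caching")
    }
    StorageLevel(useDisk = true, level.useMemory, level.deserialized, level.replication)
  }
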
+
+ test("basic lineage truncation") {
+ val numPartitions = 4
+ val parallelRdd = sc.parallelize(1 to 100, numPartitions)
+ val mappedRdd = parallelRdd.map { i => i + 1 }
+ val filteredRdd = mappedRdd.filter { i => i % 2 == 0 }
+ val expectedPartitionIndices = (0 until numPartitions).toArray
+ assert(filteredRdd.checkpointData.isEmpty)
+ assert(filteredRdd.getStorageLevel === StorageLevel.NONE)
+ assert(filteredRdd.partitions.map(_.index) === expectedPartitionIndices)
+ assert(filteredRdd.dependencies.size === 1)
+ assert(filteredRdd.dependencies.head.rdd === mappedRdd)
+ assert(mappedRdd.dependencies.size === 1)
+ assert(mappedRdd.dependencies.head.rdd === parallelRdd)
+ assert(parallelRdd.dependencies.size === 0)
+
+ // Mark the RDD for local checkpointing
+ filteredRdd.localCheckpoint()
+ assert(filteredRdd.checkpointData.isDefined)
+ assert(!filteredRdd.checkpointData.get.isCheckpointed)
+ assert(!filteredRdd.checkpointData.get.checkpointRDD.isDefined)
+ assert(filteredRdd.getStorageLevel === LocalRDDCheckpointData.DEFAULT_STORAGE_LEVEL)
+
+ // After an action, the lineage is truncated
+ val result = filteredRdd.collect()
+ assert(filteredRdd.checkpointData.get.isCheckpointed)
+ assert(filteredRdd.checkpointData.get.checkpointRDD.isDefined)
+ val checkpointRdd = filteredRdd.checkpointData.flatMap(_.checkpointRDD).get
+ assert(filteredRdd.dependencies.size === 1)
+ assert(filteredRdd.dependencies.head.rdd === checkpointRdd)
+ assert(filteredRdd.partitions.map(_.index) === expectedPartitionIndices)
+ assert(checkpointRdd.partitions.map(_.index) === expectedPartitionIndices)
+
+ // Recomputation should yield the same result
+ assert(filteredRdd.collect() === result)
+ assert(filteredRdd.collect() === result)
+ }
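
A user-visible way to observe the truncation verified above is the RDD's debug string; a small hedged example (the exact output varies by Spark version):

  val rdd = sc.parallelize(1 to 100, 4).map(_ + 1).filter(_ % 2 == 0)
  rdd.localCheckpoint()
  println(rdd.toDebugString) // full lineage; nothing checkpointed yet
  rdd.count()
  println(rdd.toDebugString) // lineage should now be rooted at the checkpoint RDD
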
+
+ test("basic lineage truncation - caching before checkpointing") {
+ testBasicLineageTruncationWithCaching(
+ newRdd.persist(StorageLevel.MEMORY_ONLY).localCheckpoint(),
+ StorageLevel.MEMORY_AND_DISK)
+ }
+
+ test("basic lineage truncation - caching after checkpointing") {
+ testBasicLineageTruncationWithCaching(
+ newRdd.localCheckpoint().persist(StorageLevel.MEMORY_ONLY),
+ StorageLevel.MEMORY_AND_DISK)
+ }
+
+ test("indirect lineage truncation") {
+ testIndirectLineageTruncation(
+ newRdd.localCheckpoint(),
+ LocalRDDCheckpointData.DEFAULT_STORAGE_LEVEL)
+ }
+
+ test("indirect lineage truncation - caching before checkpointing") {
+ testIndirectLineageTruncation(
+ newRdd.persist(StorageLevel.MEMORY_ONLY).localCheckpoint(),
+ StorageLevel.MEMORY_AND_DISK)
+ }
+
+ test("indirect lineage truncation - caching after checkpointing") {
+ testIndirectLineageTruncation(
+ newRdd.localCheckpoint().persist(StorageLevel.MEMORY_ONLY),
+ StorageLevel.MEMORY_AND_DISK)
+ }
+
+ test("checkpoint without draining iterator") {
+ testWithoutDrainingIterator(
+ newSortedRdd.localCheckpoint(),
+ LocalRDDCheckpointData.DEFAULT_STORAGE_LEVEL,
+ 50)
+ }
+
+ test("checkpoint without draining iterator - caching before checkpointing") {
+ testWithoutDrainingIterator(
+ newSortedRdd.persist(StorageLevel.MEMORY_ONLY).localCheckpoint(),
+ StorageLevel.MEMORY_AND_DISK,
+ 50)
+ }
+
+ test("checkpoint without draining iterator - caching after checkpointing") {
+ testWithoutDrainingIterator(
+ newSortedRdd.localCheckpoint().persist(StorageLevel.MEMORY_ONLY),
+ StorageLevel.MEMORY_AND_DISK,
+ 50)
+ }
+
+ test("checkpoint blocks exist") {
+ testCheckpointBlocksExist(
+ newRdd.localCheckpoint(),
+ LocalRDDCheckpointData.DEFAULT_STORAGE_LEVEL)
+ }
+
+ test("checkpoint blocks exist - caching before checkpointing") {
+ testCheckpointBlocksExist(
+ newRdd.persist(StorageLevel.MEMORY_ONLY).localCheckpoint(),
+ StorageLevel.MEMORY_AND_DISK)
+ }
+
+ test("checkpoint blocks exist - caching after checkpointing") {
+ testCheckpointBlocksExist(
+ newRdd.localCheckpoint().persist(StorageLevel.MEMORY_ONLY),
+ StorageLevel.MEMORY_AND_DISK)
+ }
+
+ test("missing checkpoint block fails with informative message") {
+ val rdd = newRdd.localCheckpoint()
+ val numPartitions = rdd.partitions.size
+ val partitionIndices = rdd.partitions.map(_.index)
+ val bmm = sc.env.blockManager.master
+
+ // After an action, the blocks should be found somewhere in the cache
+ rdd.collect()
+ partitionIndices.foreach { i =>
+ assert(bmm.contains(RDDBlockId(rdd.id, i)))
+ }
+
+ // Remove one of the blocks to simulate executor failure
+ // Collecting the RDD should now fail with an informative exception
+ val blockId = RDDBlockId(rdd.id, numPartitions - 1)
+ bmm.removeBlock(blockId)
+ try {
+ rdd.collect()
+ fail("Collect should have failed if local checkpoint block is removed...")
+ } catch {
+ case se: SparkException =>
+ assert(se.getMessage.contains(s"Checkpoint block $blockId not found"))
+ assert(se.getMessage.contains("rdd.checkpoint()")) // suggest an alternative
+ assert(se.getMessage.contains("fault-tolerant")) // justify the alternative
+ }
+ }
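
The message checked above points users toward reliable checkpointing when losing cached blocks is a real risk. A hypothetical illustration of that alternative (the checkpoint directory path here is made up):

  // Reliable checkpointing survives the loss of cached blocks, at the cost of
  // writing the data out to a fault-tolerant file system.
  sc.setCheckpointDir("hdfs:///tmp/spark-checkpoints") // hypothetical path
  val rdd = sc.parallelize(1 to 100, 4).map(_ + 1)
  rdd.checkpoint() // must be called before the first action on this RDD
  rdd.count()      // the action also writes out the checkpoint files
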
+
+ /**
+ * Helper method to create a simple RDD.
+ */
+ private def newRdd: RDD[Int] = {
+ sc.parallelize(1 to 100, 4)
+ .map { i => i + 1 }
+ .filter { i => i % 2 == 0 }
+ }
+
+ /**
+ * Helper method to create a simple sorted RDD.
+ */
+ private def newSortedRdd: RDD[Int] = newRdd.sortBy(identity)
+
+ /**
+ * Helper method to test basic lineage truncation with caching.
+ *
+ * @param rdd an RDD that is both marked for caching and local checkpointing
+ */
+ private def testBasicLineageTruncationWithCaching[T](
+ rdd: RDD[T],
+ targetStorageLevel: StorageLevel): Unit = {
+ require(targetStorageLevel !== StorageLevel.NONE)
+ require(rdd.getStorageLevel !== StorageLevel.NONE)
+ require(rdd.isLocallyCheckpointed)
+ val result = rdd.collect()
+ assert(rdd.getStorageLevel === targetStorageLevel)
+ assert(rdd.checkpointData.isDefined)
+ assert(rdd.checkpointData.get.isCheckpointed)
+ assert(rdd.checkpointData.get.checkpointRDD.isDefined)
+ assert(rdd.dependencies.head.rdd === rdd.checkpointData.get.checkpointRDD.get)
+ assert(rdd.collect() === result)
+ assert(rdd.collect() === result)
+ }
+
+ /**
+ * Helper method to test indirect lineage truncation.
+ *
+ * Indirect lineage truncation here means the action is called on one of the
+ * checkpointed RDD's descendants, but not on the checkpointed RDD itself.
+ *
+ * @param rdd a locally checkpointed RDD
+ */
+ private def testIndirectLineageTruncation[T](
+ rdd: RDD[T],
+ targetStorageLevel: StorageLevel): Unit = {
+ require(targetStorageLevel !== StorageLevel.NONE)
+ require(rdd.isLocallyCheckpointed)
+ val rdd1 = rdd.map { i => i + "1" }
+ val rdd2 = rdd1.map { i => i + "2" }
+ val rdd3 = rdd2.map { i => i + "3" }
+ val rddDependencies = rdd.dependencies
+ val rdd1Dependencies = rdd1.dependencies
+ val rdd2Dependencies = rdd2.dependencies
+ val rdd3Dependencies = rdd3.dependencies
+ assert(rdd1Dependencies.size === 1)
+ assert(rdd1Dependencies.head.rdd === rdd)
+ assert(rdd2Dependencies.size === 1)
+ assert(rdd2Dependencies.head.rdd === rdd1)
+ assert(rdd3Dependencies.size === 1)
+ assert(rdd3Dependencies.head.rdd === rdd2)
+
+ // Only the locally checkpointed RDD should have special storage level
+ assert(rdd.getStorageLevel === targetStorageLevel)
+ assert(rdd1.getStorageLevel === StorageLevel.NONE)
+ assert(rdd2.getStorageLevel === StorageLevel.NONE)
+ assert(rdd3.getStorageLevel === StorageLevel.NONE)
+
+    // After an action, only the dependencies of the checkpointed RDD change
+ val result = rdd3.collect()
+ assert(rdd.dependencies !== rddDependencies)
+ assert(rdd1.dependencies === rdd1Dependencies)
+ assert(rdd2.dependencies === rdd2Dependencies)
+ assert(rdd3.dependencies === rdd3Dependencies)
+ assert(rdd3.collect() === result)
+ assert(rdd3.collect() === result)
+ }
+
+ /**
+ * Helper method to test checkpointing without fully draining the iterator.
+ *
+ * Not all RDD actions fully consume the iterator. As a result, a subset of the partitions
+ * may not be cached. However, since we want to truncate the lineage safely, we explicitly
+ * ensure that *all* partitions are fully cached. This method asserts this behavior.
+ *
+ * @param rdd a locally checkpointed RDD
+ */
+ private def testWithoutDrainingIterator[T](
+ rdd: RDD[T],
+ targetStorageLevel: StorageLevel,
+ targetCount: Int): Unit = {
+ require(targetCount > 0)
+ require(targetStorageLevel !== StorageLevel.NONE)
+ require(rdd.isLocallyCheckpointed)
+
+ // This does not drain the iterator, but checkpointing should still work
+ val first = rdd.first()
+ assert(rdd.count() === targetCount)
+ assert(rdd.count() === targetCount)
+ assert(rdd.first() === first)
+ assert(rdd.first() === first)
+
+ // Test the same thing by calling actions on a descendant instead
+ val rdd1 = rdd.repartition(10)
+ val rdd2 = rdd1.repartition(100)
+ val rdd3 = rdd2.repartition(1000)
+ val first2 = rdd3.first()
+ assert(rdd3.count() === targetCount)
+ assert(rdd3.count() === targetCount)
+ assert(rdd3.first() === first2)
+ assert(rdd3.first() === first2)
+ assert(rdd.getStorageLevel === targetStorageLevel)
+ assert(rdd1.getStorageLevel === StorageLevel.NONE)
+ assert(rdd2.getStorageLevel === StorageLevel.NONE)
+ assert(rdd3.getStorageLevel === StorageLevel.NONE)
+ }
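
The doc comment above describes the guarantee under test: even when an action such as first() consumes only part of a partition's iterator, every partition must still be fully cached before the lineage is cut. One plausible way to provide that guarantee, sketched here with a made-up helper name (the actual implementation may differ):

  import org.apache.spark.SparkEnv
  import org.apache.spark.rdd.RDD
  import org.apache.spark.storage.RDDBlockId

  // Sketch: after the user's action, compute whichever partitions are not yet
  // cached so that every block exists before the dependencies are rewritten.
  def ensureAllPartitionsCached[T](rdd: RDD[T]): Unit = {
    val missing = rdd.partitions.map(_.index).filter { i =>
      !SparkEnv.get.blockManager.master.contains(RDDBlockId(rdd.id, i))
    }
    if (missing.nonEmpty) {
      // Draining the iterator is enough: since the RDD is persisted at a
      // disk-backed level, computing a partition also caches it.
      rdd.sparkContext.runJob(rdd, (iter: Iterator[T]) => iter.size, missing)
    }
  }
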
+
+ /**
+ * Helper method to test whether the checkpoint blocks are found in the cache.
+ *
+ * @param rdd a locally checkpointed RDD
+ */
+ private def testCheckpointBlocksExist[T](
+ rdd: RDD[T],
+ targetStorageLevel: StorageLevel): Unit = {
+ val bmm = sc.env.blockManager.master
+ val partitionIndices = rdd.partitions.map(_.index)
+
+ // The blocks should not exist before the action
+ partitionIndices.foreach { i =>
+ assert(!bmm.contains(RDDBlockId(rdd.id, i)))
+ }
+
+ // After an action, the blocks should be found in the cache with the expected level
+ rdd.collect()
+ partitionIndices.foreach { i =>
+ val blockId = RDDBlockId(rdd.id, i)
+ val status = bmm.getBlockStatus(blockId)
+ assert(status.nonEmpty)
+ assert(status.values.head.storageLevel === targetStorageLevel)
+ }
+ }
+
+}
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index f9384c4c3c..280aac9319 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -80,8 +80,13 @@ object MimaExcludes {
"org.apache.spark.mllib.linalg.Matrix.numActives")
) ++ Seq(
// SPARK-8914 Remove RDDApi
- ProblemFilters.exclude[MissingClassProblem](
- "org.apache.spark.sql.RDDApi")
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.RDDApi")
+ ) ++ Seq(
+ // SPARK-7292 Provide operator to truncate lineage cheaply
+ ProblemFilters.exclude[AbstractClassProblem](
+ "org.apache.spark.rdd.RDDCheckpointData"),
+ ProblemFilters.exclude[AbstractClassProblem](
+ "org.apache.spark.rdd.CheckpointRDD")
) ++ Seq(
// SPARK-8701 Add input metadata in the batch page.
ProblemFilters.exclude[MissingClassProblem](