Diffstat (limited to 'core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala')
 core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala | 122 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 116 insertions(+), 6 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
index a69be6a068..fa71b8c262 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
@@ -20,12 +20,12 @@ package org.apache.spark.rdd
import java.io.IOException

import scala.reflect.ClassTag
+import scala.util.control.NonFatal

import org.apache.hadoop.fs.Path

import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.util.{SerializableConfiguration, Utils}

/**
@@ -33,8 +33,9 @@ import org.apache.spark.util.{SerializableConfiguration, Utils}
*/
private[spark] class ReliableCheckpointRDD[T: ClassTag](
sc: SparkContext,
- val checkpointPath: String)
- extends CheckpointRDD[T](sc) {
+ val checkpointPath: String,
+ _partitioner: Option[Partitioner] = None
+ ) extends CheckpointRDD[T](sc) {

@transient private val hadoopConf = sc.hadoopConfiguration
@transient private val cpath = new Path(checkpointPath)
@@ -47,7 +48,13 @@ private[spark] class ReliableCheckpointRDD[T: ClassTag](
/**
* Return the path of the checkpoint directory this RDD reads data from.
*/
- override def getCheckpointFile: Option[String] = Some(checkpointPath)
+ override val getCheckpointFile: Option[String] = Some(checkpointPath)
+
+ override val partitioner: Option[Partitioner] = {
+ _partitioner.orElse {
+ ReliableCheckpointRDD.readCheckpointedPartitionerFile(context, checkpointPath)
+ }
+ }

/**
* Return partitions described by the files in the checkpoint directory.
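With this change, a checkpoint written in the same application passes the partitioner through the new constructor argument, while a checkpoint reloaded from files alone recovers it from the companion "_partitioner" file via readCheckpointedPartitionerFile. A minimal sketch of the behavior this preserves, assuming hypothetical driver code in local mode (not part of this patch):

    import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}

    val sc = new SparkContext(new SparkConf().setAppName("ckpt").setMaster("local[2]"))
    sc.setCheckpointDir("/tmp/ckpt")

    // A hash-partitioned pair RDD; the partitioner should survive checkpointing.
    val pairs = sc.parallelize(1 to 1000)
      .map(i => (i % 10, i))
      .partitionBy(new HashPartitioner(4))
    pairs.checkpoint()
    pairs.count() // materializes the checkpoint, including the _partitioner file

    // Co-partitioned operations (e.g. a join with another RDD that uses the
    // same partitioner) can skip a shuffle even after lineage is truncated.
    assert(pairs.partitioner == Some(new HashPartitioner(4)))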
@@ -100,10 +107,52 @@ private[spark] object ReliableCheckpointRDD extends Logging {
"part-%05d".format(partitionIndex)
}

+ private def checkpointPartitionerFileName(): String = {
+ "_partitioner"
+ }
+
+ /**
+ * Write an RDD to checkpoint files and return a ReliableCheckpointRDD representing the RDD.
+ */
+ def writeRDDToCheckpointDirectory[T: ClassTag](
+ originalRDD: RDD[T],
+ checkpointDir: String,
+ blockSize: Int = -1): ReliableCheckpointRDD[T] = {
+
+ val sc = originalRDD.sparkContext
+
+ // Create the output path for the checkpoint
+ val checkpointDirPath = new Path(checkpointDir)
+ val fs = checkpointDirPath.getFileSystem(sc.hadoopConfiguration)
+ if (!fs.mkdirs(checkpointDirPath)) {
+ throw new SparkException(s"Failed to create checkpoint path $checkpointDirPath")
+ }
+
+ // Save to file, and reload it as an RDD
+ val broadcastedConf = sc.broadcast(
+ new SerializableConfiguration(sc.hadoopConfiguration))
+ // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
+ sc.runJob(originalRDD,
+ writePartitionToCheckpointFile[T](checkpointDirPath.toString, broadcastedConf) _)
+
+ if (originalRDD.partitioner.nonEmpty) {
+ writePartitionerToCheckpointDir(sc, originalRDD.partitioner.get, checkpointDirPath)
+ }
+
+ val newRDD = new ReliableCheckpointRDD[T](
+ sc, checkpointDirPath.toString, originalRDD.partitioner)
+ if (newRDD.partitions.length != originalRDD.partitions.length) {
+ throw new SparkException(
+ s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " +
+ s"number of partitions from original RDD $originalRDD(${originalRDD.partitions.length})")
+ }
+ newRDD
+ }
+
/**
- * Write this partition's values to a checkpoint file.
+ * Write an RDD partition's data to a checkpoint file.
*/
- def writeCheckpointFile[T: ClassTag](
+ def writePartitionToCheckpointFile[T: ClassTag](
path: String,
broadcastedConf: Broadcast[SerializableConfiguration],
blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
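The new writeRDDToCheckpointDirectory helper above fans out one task per partition through sc.runJob, with writePartitionToCheckpointFile as the job function and part-%05d file names derived from the partition index. A standalone sketch of the same fan-out pattern, assuming local mode and plain java.io in place of the Hadoop FileSystem API (hypothetical, not part of this patch):

    import java.io.{File, FileOutputStream, ObjectOutputStream}
    import org.apache.spark.{SparkConf, SparkContext, TaskContext}

    val sc = new SparkContext(new SparkConf().setAppName("sketch").setMaster("local[2]"))
    val rdd = sc.parallelize(1 to 100, 4)
    val dir = new File("/tmp/ckpt-sketch")
    dir.mkdirs()

    // One writer task per partition: the TaskContext supplies the partition
    // index that names the output file, as checkpointFileName does above.
    sc.runJob(rdd, (ctx: TaskContext, iter: Iterator[Int]) => {
      val file = new File(dir, "part-%05d".format(ctx.partitionId()))
      val out = new ObjectOutputStream(new FileOutputStream(file))
      try iter.foreach(x => out.writeObject(x)) finally out.close()
    })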
@@ -152,6 +201,67 @@ private[spark] object ReliableCheckpointRDD extends Logging {
}

/**
+ * Write a partitioner to the given RDD checkpoint directory. This is done on a best-effort
+ * basis; any exception while writing the partitioner is caught, logged and ignored.
+ */
+ private def writePartitionerToCheckpointDir(
+ sc: SparkContext, partitioner: Partitioner, checkpointDirPath: Path): Unit = {
+ try {
+ val partitionerFilePath = new Path(checkpointDirPath, checkpointPartitionerFileName)
+ val bufferSize = sc.conf.getInt("spark.buffer.size", 65536)
+ val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration)
+ val fileOutputStream = fs.create(partitionerFilePath, false, bufferSize)
+ val serializer = SparkEnv.get.serializer.newInstance()
+ val serializeStream = serializer.serializeStream(fileOutputStream)
+ Utils.tryWithSafeFinally {
+ serializeStream.writeObject(partitioner)
+ } {
+ serializeStream.close()
+ }
+ logDebug(s"Written partitioner to $partitionerFilePath")
+ } catch {
+ case NonFatal(e) =>
+ logWarning(s"Error writing partitioner $partitioner to $checkpointDirPath")
+ }
+ }
+
+ /**
+ * Read a partitioner from the given RDD checkpoint directory, if it exists.
+ * This is done on a best-effort basis; any exception while reading the partitioner is
+ * caught, logged and ignored.
+ */
+ private def readCheckpointedPartitionerFile(
+ sc: SparkContext,
+ checkpointDirPath: String): Option[Partitioner] = {
+ try {
+ val bufferSize = sc.conf.getInt("spark.buffer.size", 65536)
+ val partitionerFilePath = new Path(checkpointDirPath, checkpointPartitionerFileName)
+ val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration)
+ if (fs.exists(partitionerFilePath)) {
+ val fileInputStream = fs.open(partitionerFilePath, bufferSize)
+ val serializer = SparkEnv.get.serializer.newInstance()
+ val deserializeStream = serializer.deserializeStream(fileInputStream)
+ val partitioner = Utils.tryWithSafeFinally[Partitioner] {
+ deserializeStream.readObject[Partitioner]
+ } {
+ deserializeStream.close()
+ }
+ logDebug(s"Read partitioner from $partitionerFilePath")
+ Some(partitioner)
+ } else {
+ logDebug("No partitioner file")
+ None
+ }
+ } catch {
+ case NonFatal(e) =>
+ logWarning(s"Error reading partitioner from $checkpointDirPath, " +
+ s"partitioner will not be recovered which may lead to performance loss", e)
+ None
+ }
+ }
+
+ /**
* Read the content of the specified checkpoint file.
*/
def readCheckpointFile[T](
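For reference, the "_partitioner" file written and read by the two helpers above has a simple format: a single object serialized with the environment's default serializer. A standalone sketch of the same round trip, assuming JavaSerializer and local files in place of SparkEnv and the Hadoop FileSystem (hypothetical, not part of this patch):

    import java.nio.file.{Files, Paths}
    import org.apache.spark.{HashPartitioner, Partitioner, SparkConf}
    import org.apache.spark.serializer.JavaSerializer

    val ser = new JavaSerializer(new SparkConf()).newInstance()
    val path = Paths.get("/tmp/_partitioner")

    // Write one serialized Partitioner, as writePartitionerToCheckpointDir does.
    val out = ser.serializeStream(Files.newOutputStream(path))
    try out.writeObject(new HashPartitioner(8): Partitioner) finally out.close()

    // Read it back, as readCheckpointedPartitionerFile does.
    val in = ser.deserializeStream(Files.newInputStream(path))
    val p = try in.readObject[Partitioner]() finally in.close()
    assert(p == new HashPartitioner(8)) // HashPartitioner compares by partition count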