diff options
Diffstat (limited to 'core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala')
-rw-r--r-- | core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala | 17 |
1 files changed, 12 insertions, 5 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index 3b56e45aa9..bc688110f4 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -22,7 +22,7 @@ import scala.reflect.ClassTag import org.apache.hadoop.fs.Path import org.apache.hadoop.conf.Configuration -import org.apache.spark.{Partition, SparkException, Logging} +import org.apache.spark.{SerializableWritable, Partition, SparkException, Logging} import org.apache.spark.scheduler.{ResultTask, ShuffleMapTask} /** @@ -40,7 +40,7 @@ private[spark] object CheckpointState extends Enumeration { * manages the post-checkpoint state by providing the updated partitions, iterator and preferred locations * of the checkpointed RDD. */ -private[spark] class RDDCheckpointData[T: ClassTag](rdd: RDD[T]) +private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) extends Logging with Serializable { import CheckpointState._ @@ -85,14 +85,21 @@ private[spark] class RDDCheckpointData[T: ClassTag](rdd: RDD[T]) // Create the output path for the checkpoint val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id) - val fs = path.getFileSystem(new Configuration()) + val fs = path.getFileSystem(rdd.context.hadoopConfiguration) if (!fs.mkdirs(path)) { throw new SparkException("Failed to create checkpoint path " + path) } // Save to file, and reload it as an RDD - rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path.toString) _) + val broadcastedConf = rdd.context.broadcast( + new SerializableWritable(rdd.context.hadoopConfiguration)) + rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path.toString, broadcastedConf) _) val newRDD = new CheckpointRDD[T](rdd.context, path.toString) + if (newRDD.partitions.size != rdd.partitions.size) { + throw new SparkException( + "Checkpoint RDD " + newRDD + "("+ newRDD.partitions.size + ") has different " + + "number of partitions than original RDD " + rdd + "(" + rdd.partitions.size + ")") + } // Change the dependencies and partitions of the RDD RDDCheckpointData.synchronized { @@ -101,8 +108,8 @@ private[spark] class RDDCheckpointData[T: ClassTag](rdd: RDD[T]) rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions cpState = Checkpointed RDDCheckpointData.clearTaskCaches() - logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id) } + logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id) } // Get preferred location of a split after checkpointing |