aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatei Zaharia <matei@eecs.berkeley.edu>2013-01-28 22:30:12 -0800
committerMatei Zaharia <matei@eecs.berkeley.edu>2013-01-28 22:30:12 -0800
commit64ba6a8c2c5f46e6de6deb6a6fd576a55cb3b198 (patch)
treeddf8897a1e3e3efa02bb8dedf6e8592bb9fb2494
parentb29599e5cf0272f0d0e3ceceebb473a8163eab8c (diff)
downloadspark-64ba6a8c2c5f46e6de6deb6a6fd576a55cb3b198.tar.gz
spark-64ba6a8c2c5f46e6de6deb6a6fd576a55cb3b198.tar.bz2
spark-64ba6a8c2c5f46e6de6deb6a6fd576a55cb3b198.zip
Simplify checkpointing code and RDD class a little:
- RDD's getDependencies and getSplits methods are now guaranteed to be called only once, so subclasses can safely do computation in there without worrying about caching the results. - The management of a "splits_" variable that is cleared out when we checkpoint an RDD is now done in the RDD class. - A few of the RDD subclasses are simpler. - CheckpointRDD's compute() method no longer assumes that it is given a CheckpointRDDSplit -- it can work just as well on a split from the original RDD, because it only looks at its index. This is important because things like UnionRDD and ZippedRDD remember the parent's splits as part of their own and wouldn't work on checkpointed parents. - RDD.iterator can now reuse cached data if an RDD is computed before it is checkpointed. It seems like it wouldn't do this before (it always called iterator() on the CheckpointRDD, which read from HDFS).
-rw-r--r--core/src/main/scala/spark/CacheManager.scala6
-rw-r--r--core/src/main/scala/spark/PairRDDFunctions.scala4
-rw-r--r--core/src/main/scala/spark/RDD.scala130
-rw-r--r--core/src/main/scala/spark/RDDCheckpointData.scala19
-rw-r--r--core/src/main/scala/spark/api/java/JavaRDDLike.scala2
-rw-r--r--core/src/main/scala/spark/rdd/CartesianRDD.scala12
-rw-r--r--core/src/main/scala/spark/rdd/CheckpointRDD.scala61
-rw-r--r--core/src/main/scala/spark/rdd/CoalescedRDD.scala14
-rw-r--r--core/src/main/scala/spark/rdd/MappedRDD.scala6
-rw-r--r--core/src/main/scala/spark/rdd/PartitionPruningRDD.scala13
-rw-r--r--core/src/main/scala/spark/rdd/ShuffledRDD.scala8
-rw-r--r--core/src/main/scala/spark/rdd/UnionRDD.scala14
-rw-r--r--core/src/main/scala/spark/rdd/ZippedRDD.scala7
-rw-r--r--core/src/main/scala/spark/util/MetadataCleaner.scala4
-rw-r--r--core/src/test/scala/spark/CheckpointSuite.scala21
15 files changed, 153 insertions, 168 deletions
diff --git a/core/src/main/scala/spark/CacheManager.scala b/core/src/main/scala/spark/CacheManager.scala
index a0b53fd9d6..711435c333 100644
--- a/core/src/main/scala/spark/CacheManager.scala
+++ b/core/src/main/scala/spark/CacheManager.scala
@@ -10,9 +10,9 @@ import spark.storage.{BlockManager, StorageLevel}
private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
private val loading = new HashSet[String]
- /** Gets or computes an RDD split. Used by RDD.iterator() when a RDD is cached. */
+ /** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */
def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel)
- : Iterator[T] = {
+ : Iterator[T] = {
val key = "rdd_%d_%d".format(rdd.id, split.index)
logInfo("Cache key is " + key)
blockManager.get(key) match {
@@ -50,7 +50,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
// If we got here, we have to load the split
val elements = new ArrayBuffer[Any]
logInfo("Computing partition " + split)
- elements ++= rdd.compute(split, context)
+ elements ++= rdd.computeOrReadCheckpoint(split, context)
// Try to put this block in the blockManager
blockManager.put(key, elements, storageLevel, true)
return elements.iterator.asInstanceOf[Iterator[T]]
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index 53b051f1c5..231e23a7de 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -649,9 +649,7 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
}
private[spark]
-class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U)
- extends RDD[(K, U)](prev) {
-
+class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)](prev) {
override def getSplits = firstParent[(K, V)].splits
override val partitioner = firstParent[(K, V)].partitioner
override def compute(split: Split, context: TaskContext) =
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 0d3857f9dd..dbad6d4c83 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -1,27 +1,17 @@
package spark
-import java.io.{ObjectOutputStream, IOException, EOFException, ObjectInputStream}
import java.net.URL
import java.util.{Date, Random}
import java.util.{HashMap => JHashMap}
-import java.util.concurrent.atomic.AtomicLong
import scala.collection.Map
import scala.collection.JavaConversions.mapAsScalaMap
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
-import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text
-import org.apache.hadoop.io.Writable
-import org.apache.hadoop.mapred.FileOutputCommitter
-import org.apache.hadoop.mapred.HadoopWriter
-import org.apache.hadoop.mapred.JobConf
-import org.apache.hadoop.mapred.OutputCommitter
-import org.apache.hadoop.mapred.OutputFormat
-import org.apache.hadoop.mapred.SequenceFileOutputFormat
import org.apache.hadoop.mapred.TextOutputFormat
import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
@@ -30,7 +20,6 @@ import spark.partial.BoundedDouble
import spark.partial.CountEvaluator
import spark.partial.GroupedCountEvaluator
import spark.partial.PartialResult
-import spark.rdd.BlockRDD
import spark.rdd.CartesianRDD
import spark.rdd.FilteredRDD
import spark.rdd.FlatMappedRDD
@@ -73,11 +62,11 @@ import SparkContext._
* on RDD internals.
*/
abstract class RDD[T: ClassManifest](
- @transient var sc: SparkContext,
- var dependencies_ : List[Dependency[_]]
+ @transient private var sc: SparkContext,
+ @transient private var deps: Seq[Dependency[_]]
) extends Serializable with Logging {
-
+ /** Construct an RDD with just a one-to-one dependency on one parent */
def this(@transient oneParent: RDD[_]) =
this(oneParent.context , List(new OneToOneDependency(oneParent)))
@@ -85,25 +74,27 @@ abstract class RDD[T: ClassManifest](
// Methods that should be implemented by subclasses of RDD
// =======================================================================
- /** Function for computing a given partition. */
+ /** Implemented by subclasses to compute a given partition. */
def compute(split: Split, context: TaskContext): Iterator[T]
- /** Set of partitions in this RDD. */
- protected def getSplits(): Array[Split]
+ /**
+ * Implemented by subclasses to return the set of partitions in this RDD. This method will only
+ * be called once, so it is safe to implement a time-consuming computation in it.
+ */
+ protected def getSplits: Array[Split]
- /** How this RDD depends on any parent RDDs. */
- protected def getDependencies(): List[Dependency[_]] = dependencies_
+ /**
+ * Implemented by subclasses to return how this RDD depends on parent RDDs. This method will only
+ * be called once, so it is safe to implement a time-consuming computation in it.
+ */
+ protected def getDependencies: Seq[Dependency[_]] = deps
- /** A friendly name for this RDD */
- var name: String = null
-
/** Optionally overridden by subclasses to specify placement preferences. */
protected def getPreferredLocations(split: Split): Seq[String] = Nil
/** Optionally overridden by subclasses to specify how they are partitioned. */
val partitioner: Option[Partitioner] = None
-
// =======================================================================
// Methods and fields available on all RDDs
// =======================================================================
@@ -111,13 +102,16 @@ abstract class RDD[T: ClassManifest](
/** A unique ID for this RDD (within its SparkContext). */
val id = sc.newRddId()
+ /** A friendly name for this RDD */
+ var name: String = null
+
/** Assign a name to this RDD */
def setName(_name: String) = {
name = _name
this
}
- /**
+ /**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. Can only be called once on each RDD.
*/
@@ -142,15 +136,24 @@ abstract class RDD[T: ClassManifest](
/** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
def getStorageLevel = storageLevel
+ // Our dependencies and splits will be gotten by calling subclass's methods below, and will
+ // be overwritten when we're checkpointed
+ private var dependencies_ : Seq[Dependency[_]] = null
+ @transient private var splits_ : Array[Split] = null
+
+ /** An Option holding our checkpoint RDD, if we are checkpointed */
+ private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD)
+
/**
- * Get the preferred location of a split, taking into account whether the
+ * Get the list of dependencies of this RDD, taking into account whether the
* RDD is checkpointed or not.
*/
- final def preferredLocations(split: Split): Seq[String] = {
- if (isCheckpointed) {
- checkpointData.get.getPreferredLocations(split)
- } else {
- getPreferredLocations(split)
+ final def dependencies: Seq[Dependency[_]] = {
+ checkpointRDD.map(r => List(new OneToOneDependency(r))).getOrElse {
+ if (dependencies_ == null) {
+ dependencies_ = getDependencies
+ }
+ dependencies_
}
}
@@ -159,22 +162,21 @@ abstract class RDD[T: ClassManifest](
* RDD is checkpointed or not.
*/
final def splits: Array[Split] = {
- if (isCheckpointed) {
- checkpointData.get.getSplits
- } else {
- getSplits
+ checkpointRDD.map(_.splits).getOrElse {
+ if (splits_ == null) {
+ splits_ = getSplits
+ }
+ splits_
}
}
/**
- * Get the list of dependencies of this RDD, taking into account whether the
+ * Get the preferred location of a split, taking into account whether the
* RDD is checkpointed or not.
*/
- final def dependencies: List[Dependency[_]] = {
- if (isCheckpointed) {
- dependencies_
- } else {
- getDependencies
+ final def preferredLocations(split: Split): Seq[String] = {
+ checkpointRDD.map(_.getPreferredLocations(split)).getOrElse {
+ getPreferredLocations(split)
}
}
@@ -184,11 +186,20 @@ abstract class RDD[T: ClassManifest](
* subclasses of RDD.
*/
final def iterator(split: Split, context: TaskContext): Iterator[T] = {
- if (isCheckpointed) {
- checkpointData.get.iterator(split, context)
- } else if (storageLevel != StorageLevel.NONE) {
+ if (storageLevel != StorageLevel.NONE) {
SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
} else {
+ computeOrReadCheckpoint(split, context)
+ }
+ }
+
+ /**
+ * Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing.
+ */
+ private[spark] def computeOrReadCheckpoint(split: Split, context: TaskContext): Iterator[T] = {
+ if (isCheckpointed) {
+ firstParent[T].iterator(split, context)
+ } else {
compute(split, context)
}
}
@@ -578,15 +589,15 @@ abstract class RDD[T: ClassManifest](
/**
* Return whether this RDD has been checkpointed or not
*/
- def isCheckpointed(): Boolean = {
- if (checkpointData.isDefined) checkpointData.get.isCheckpointed() else false
+ def isCheckpointed: Boolean = {
+ checkpointData.map(_.isCheckpointed).getOrElse(false)
}
/**
* Gets the name of the file to which this RDD was checkpointed
*/
- def getCheckpointFile(): Option[String] = {
- if (checkpointData.isDefined) checkpointData.get.getCheckpointFile() else None
+ def getCheckpointFile: Option[String] = {
+ checkpointData.flatMap(_.getCheckpointFile)
}
// =======================================================================
@@ -611,31 +622,36 @@ abstract class RDD[T: ClassManifest](
def context = sc
/**
- * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler
+ * Performs the checkpointing of this RDD by saving this. It is called by the DAGScheduler
* after a job using this RDD has completed (therefore the RDD has been materialized and
* potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs.
*/
- protected[spark] def doCheckpoint() {
- if (checkpointData.isDefined) checkpointData.get.doCheckpoint()
- dependencies.foreach(_.rdd.doCheckpoint())
+ private[spark] def doCheckpoint() {
+ if (checkpointData.isDefined) {
+ checkpointData.get.doCheckpoint()
+ } else {
+ dependencies.foreach(_.rdd.doCheckpoint())
+ }
}
/**
- * Changes the dependencies of this RDD from its original parents to the new RDD
- * (`newRDD`) created from the checkpoint file.
+ * Changes the dependencies of this RDD from its original parents to a new RDD (`newRDD`)
+ * created from the checkpoint file, and forget its old dependencies and splits.
*/
- protected[spark] def changeDependencies(newRDD: RDD[_]) {
+ private[spark] def markCheckpointed(checkpointRDD: RDD[_]) {
clearDependencies()
- dependencies_ = List(new OneToOneDependency(newRDD))
+ dependencies_ = null
+ splits_ = null
+ deps = null // Forget the constructor argument for dependencies too
}
/**
* Clears the dependencies of this RDD. This method must ensure that all references
* to the original parent RDDs is removed to enable the parent RDDs to be garbage
* collected. Subclasses of RDD may override this method for implementing their own cleaning
- * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea.
+ * logic. See [[spark.rdd.UnionRDD]] for an example.
*/
- protected[spark] def clearDependencies() {
+ protected def clearDependencies() {
dependencies_ = null
}
}
diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala
index 18df530b7d..a4a4ebaf53 100644
--- a/core/src/main/scala/spark/RDDCheckpointData.scala
+++ b/core/src/main/scala/spark/RDDCheckpointData.scala
@@ -20,7 +20,7 @@ private[spark] object CheckpointState extends Enumeration {
* of the checkpointed RDD.
*/
private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
-extends Logging with Serializable {
+ extends Logging with Serializable {
import CheckpointState._
@@ -31,7 +31,7 @@ extends Logging with Serializable {
@transient var cpFile: Option[String] = None
// The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD.
- @transient var cpRDD: Option[RDD[T]] = None
+ var cpRDD: Option[RDD[T]] = None
// Mark the RDD for checkpointing
def markForCheckpoint() {
@@ -41,12 +41,12 @@ extends Logging with Serializable {
}
// Is the RDD already checkpointed
- def isCheckpointed(): Boolean = {
+ def isCheckpointed: Boolean = {
RDDCheckpointData.synchronized { cpState == Checkpointed }
}
// Get the file to which this RDD was checkpointed to as an Option
- def getCheckpointFile(): Option[String] = {
+ def getCheckpointFile: Option[String] = {
RDDCheckpointData.synchronized { cpFile }
}
@@ -71,7 +71,7 @@ extends Logging with Serializable {
RDDCheckpointData.synchronized {
cpFile = Some(path)
cpRDD = Some(newRDD)
- rdd.changeDependencies(newRDD)
+ rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and splits
cpState = Checkpointed
RDDCheckpointData.clearTaskCaches()
logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id)
@@ -79,7 +79,7 @@ extends Logging with Serializable {
}
// Get preferred location of a split after checkpointing
- def getPreferredLocations(split: Split) = {
+ def getPreferredLocations(split: Split): Seq[String] = {
RDDCheckpointData.synchronized {
cpRDD.get.preferredLocations(split)
}
@@ -91,9 +91,10 @@ extends Logging with Serializable {
}
}
- // Get iterator. This is called at the worker nodes.
- def iterator(split: Split, context: TaskContext): Iterator[T] = {
- rdd.firstParent[T].iterator(split, context)
+ def checkpointRDD: Option[RDD[T]] = {
+ RDDCheckpointData.synchronized {
+ cpRDD
+ }
}
}
diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
index 4c95c989b5..46fd8fe85e 100644
--- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
@@ -319,7 +319,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround
/**
* Return whether this RDD has been checkpointed or not
*/
- def isCheckpointed(): Boolean = rdd.isCheckpointed()
+ def isCheckpointed: Boolean = rdd.isCheckpointed
/**
* Gets the name of the file to which this RDD was checkpointed
diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala
index 453d410ad4..0f9ca06531 100644
--- a/core/src/main/scala/spark/rdd/CartesianRDD.scala
+++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala
@@ -1,7 +1,7 @@
package spark.rdd
import java.io.{ObjectOutputStream, IOException}
-import spark.{OneToOneDependency, NarrowDependency, RDD, SparkContext, Split, TaskContext}
+import spark._
private[spark]
@@ -35,7 +35,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
val numSplitsInRdd2 = rdd2.splits.size
- @transient var splits_ = {
+ override def getSplits: Array[Split] = {
// create the cross product split
val array = new Array[Split](rdd1.splits.size * rdd2.splits.size)
for (s1 <- rdd1.splits; s2 <- rdd2.splits) {
@@ -45,8 +45,6 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
array
}
- override def getSplits = splits_
-
override def getPreferredLocations(split: Split) = {
val currSplit = split.asInstanceOf[CartesianSplit]
rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)
@@ -58,7 +56,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
}
- var deps_ = List(
+ override def getDependencies: Seq[Dependency[_]] = List(
new NarrowDependency(rdd1) {
def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2)
},
@@ -67,11 +65,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
}
)
- override def getDependencies = deps_
-
override def clearDependencies() {
- deps_ = Nil
- splits_ = null
rdd1 = null
rdd2 = null
}
diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
index 6f00f6ac73..96b593ba7c 100644
--- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
@@ -9,23 +9,26 @@ import org.apache.hadoop.fs.Path
import java.io.{File, IOException, EOFException}
import java.text.NumberFormat
-private[spark] class CheckpointRDDSplit(idx: Int, val splitFile: String) extends Split {
- override val index: Int = idx
-}
+private[spark] class CheckpointRDDSplit(val index: Int) extends Split {}
/**
* This RDD represents a RDD checkpoint file (similar to HadoopRDD).
*/
private[spark]
-class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String)
+class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: String)
extends RDD[T](sc, Nil) {
- @transient val path = new Path(checkpointPath)
- @transient val fs = path.getFileSystem(new Configuration())
+ @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)
@transient val splits_ : Array[Split] = {
- val splitFiles = fs.listStatus(path).map(_.getPath.toString).filter(_.contains("part-")).sorted
- splitFiles.zipWithIndex.map(x => new CheckpointRDDSplit(x._2, x._1)).toArray
+ val dirContents = fs.listStatus(new Path(checkpointPath))
+ val splitFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted
+ val numSplits = splitFiles.size
+ if (!splitFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) ||
+ !splitFiles(numSplits-1).endsWith(CheckpointRDD.splitIdToFile(numSplits-1))) {
+ throw new SparkException("Invalid checkpoint directory: " + checkpointPath)
+ }
+ Array.tabulate(numSplits)(i => new CheckpointRDDSplit(i))
}
checkpointData = Some(new RDDCheckpointData[T](this))
@@ -34,36 +37,34 @@ class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String)
override def getSplits = splits_
override def getPreferredLocations(split: Split): Seq[String] = {
- val status = fs.getFileStatus(path)
+ val status = fs.getFileStatus(new Path(checkpointPath))
val locations = fs.getFileBlockLocations(status, 0, status.getLen)
- locations.firstOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
+ locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
}
override def compute(split: Split, context: TaskContext): Iterator[T] = {
- CheckpointRDD.readFromFile(split.asInstanceOf[CheckpointRDDSplit].splitFile, context)
+ val file = new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index))
+ CheckpointRDD.readFromFile(file, context)
}
override def checkpoint() {
- // Do nothing. Hadoop RDD should not be checkpointed.
+ // Do nothing. CheckpointRDD should not be checkpointed.
}
}
private[spark] object CheckpointRDD extends Logging {
- def splitIdToFileName(splitId: Int): String = {
- val numfmt = NumberFormat.getInstance()
- numfmt.setMinimumIntegerDigits(5)
- numfmt.setGroupingUsed(false)
- "part-" + numfmt.format(splitId)
+ def splitIdToFile(splitId: Int): String = {
+ "part-%05d".format(splitId)
}
- def writeToFile[T](path: String, blockSize: Int = -1)(context: TaskContext, iterator: Iterator[T]) {
+ def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
val outputDir = new Path(path)
val fs = outputDir.getFileSystem(new Configuration())
- val finalOutputName = splitIdToFileName(context.splitId)
+ val finalOutputName = splitIdToFile(ctx.splitId)
val finalOutputPath = new Path(outputDir, finalOutputName)
- val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + context.attemptId)
+ val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptId)
if (fs.exists(tempOutputPath)) {
throw new IOException("Checkpoint failed: temporary path " +
@@ -83,22 +84,22 @@ private[spark] object CheckpointRDD extends Logging {
serializeStream.close()
if (!fs.rename(tempOutputPath, finalOutputPath)) {
- if (!fs.delete(finalOutputPath, true)) {
- throw new IOException("Checkpoint failed: failed to delete earlier output of task "
- + context.attemptId)
- }
- if (!fs.rename(tempOutputPath, finalOutputPath)) {
+ if (!fs.exists(finalOutputPath)) {
+ fs.delete(tempOutputPath, false)
throw new IOException("Checkpoint failed: failed to save output of task: "
- + context.attemptId)
+ + ctx.attemptId + " and final output path does not exist")
+ } else {
+ // Some other copy of this task must've finished before us and renamed it
+ logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it")
+ fs.delete(tempOutputPath, false)
}
}
}
- def readFromFile[T](path: String, context: TaskContext): Iterator[T] = {
- val inputPath = new Path(path)
- val fs = inputPath.getFileSystem(new Configuration())
+ def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = {
+ val fs = path.getFileSystem(new Configuration())
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
- val fileInputStream = fs.open(inputPath, bufferSize)
+ val fileInputStream = fs.open(path, bufferSize)
val serializer = SparkEnv.get.serializer.newInstance()
val deserializeStream = serializer.deserializeStream(fileInputStream)
diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
index 167755bbba..4c57434b65 100644
--- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
@@ -27,11 +27,11 @@ private[spark] case class CoalescedRDDSplit(
* or to avoid having a large number of small tasks when processing a directory with many files.
*/
class CoalescedRDD[T: ClassManifest](
- var prev: RDD[T],
+ @transient var prev: RDD[T],
maxPartitions: Int)
- extends RDD[T](prev.context, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs
+ extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies
- @transient var splits_ : Array[Split] = {
+ override def getSplits: Array[Split] = {
val prevSplits = prev.splits
if (prevSplits.length < maxPartitions) {
prevSplits.map(_.index).map{idx => new CoalescedRDDSplit(idx, prev, Array(idx)) }
@@ -44,26 +44,20 @@ class CoalescedRDD[T: ClassManifest](
}
}
- override def getSplits = splits_
-
override def compute(split: Split, context: TaskContext): Iterator[T] = {
split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { parentSplit =>
firstParent[T].iterator(parentSplit, context)
}
}
- var deps_ : List[Dependency[_]] = List(
+ override def getDependencies: Seq[Dependency[_]] = List(
new NarrowDependency(prev) {
def getParents(id: Int): Seq[Int] =
splits(id).asInstanceOf[CoalescedRDDSplit].parentsIndices
}
)
- override def getDependencies() = deps_
-
override def clearDependencies() {
- deps_ = Nil
- splits_ = null
prev = null
}
}
diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala
index c6ceb272cd..5466c9c657 100644
--- a/core/src/main/scala/spark/rdd/MappedRDD.scala
+++ b/core/src/main/scala/spark/rdd/MappedRDD.scala
@@ -3,13 +3,11 @@ package spark.rdd
import spark.{RDD, Split, TaskContext}
private[spark]
-class MappedRDD[U: ClassManifest, T: ClassManifest](
- prev: RDD[T],
- f: T => U)
+class MappedRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: T => U)
extends RDD[U](prev) {
override def getSplits = firstParent[T].splits
override def compute(split: Split, context: TaskContext) =
firstParent[T].iterator(split, context).map(f)
-} \ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala
index 97dd37950e..b8482338c6 100644
--- a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala
+++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala
@@ -7,23 +7,18 @@ import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext}
* all partitions. An example use case: If we know the RDD is partitioned by range,
* and the execution DAG has a filter on the key, we can avoid launching tasks
* on partitions that don't have the range covering the key.
+ *
+ * TODO: This currently doesn't give partition IDs properly!
*/
class PartitionPruningRDD[T: ClassManifest](
@transient prev: RDD[T],
@transient partitionFilterFunc: Int => Boolean)
extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) {
- @transient
- var partitions_ : Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].partitions
-
override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context)
- override protected def getSplits = partitions_
+ override protected def getSplits =
+ getDependencies.head.asInstanceOf[PruneDependency[T]].partitions
override val partitioner = firstParent[T].partitioner
-
- override def clearDependencies() {
- super.clearDependencies()
- partitions_ = null
- }
}
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index 28ff19876d..d396478673 100644
--- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -22,16 +22,10 @@ class ShuffledRDD[K, V](
override val partitioner = Some(part)
- @transient var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i))
-
- override def getSplits = splits_
+ override def getSplits = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i))
override def compute(split: Split, context: TaskContext): Iterator[(K, V)] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index)
}
-
- override def clearDependencies() {
- splits_ = null
- }
}
diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala
index 82f0a44ecd..26a2d511f2 100644
--- a/core/src/main/scala/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/spark/rdd/UnionRDD.scala
@@ -26,9 +26,9 @@ private[spark] class UnionSplit[T: ClassManifest](idx: Int, rdd: RDD[T], splitIn
class UnionRDD[T: ClassManifest](
sc: SparkContext,
@transient var rdds: Seq[RDD[T]])
- extends RDD[T](sc, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs
+ extends RDD[T](sc, Nil) { // Nil since we implement getDependencies
- @transient var splits_ : Array[Split] = {
+ override def getSplits: Array[Split] = {
val array = new Array[Split](rdds.map(_.splits.size).sum)
var pos = 0
for (rdd <- rdds; split <- rdd.splits) {
@@ -38,20 +38,16 @@ class UnionRDD[T: ClassManifest](
array
}
- override def getSplits = splits_
-
- @transient var deps_ = {
+ override def getDependencies: Seq[Dependency[_]] = {
val deps = new ArrayBuffer[Dependency[_]]
var pos = 0
for (rdd <- rdds) {
deps += new RangeDependency(rdd, 0, pos, rdd.splits.size)
pos += rdd.splits.size
}
- deps.toList
+ deps
}
- override def getDependencies = deps_
-
override def compute(s: Split, context: TaskContext): Iterator[T] =
s.asInstanceOf[UnionSplit[T]].iterator(context)
@@ -59,8 +55,6 @@ class UnionRDD[T: ClassManifest](
s.asInstanceOf[UnionSplit[T]].preferredLocations()
override def clearDependencies() {
- deps_ = null
- splits_ = null
rdds = null
}
}
diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala
index d950b06c85..e5df6d8c72 100644
--- a/core/src/main/scala/spark/rdd/ZippedRDD.scala
+++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala
@@ -32,9 +32,7 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest](
extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2)))
with Serializable {
- // TODO: FIX THIS.
-
- @transient var splits_ : Array[Split] = {
+ override def getSplits: Array[Split] = {
if (rdd1.splits.size != rdd2.splits.size) {
throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
}
@@ -45,8 +43,6 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest](
array
}
- override def getSplits = splits_
-
override def compute(s: Split, context: TaskContext): Iterator[(T, U)] = {
val (split1, split2) = s.asInstanceOf[ZippedSplit[T, U]].splits
rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context))
@@ -58,7 +54,6 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest](
}
override def clearDependencies() {
- splits_ = null
rdd1 = null
rdd2 = null
}
diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala
index 6cf93a9b17..eaff7ae581 100644
--- a/core/src/main/scala/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/spark/util/MetadataCleaner.scala
@@ -26,8 +26,8 @@ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging
if (delaySeconds > 0) {
logDebug(
- "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds and "
- + "period of " + periodSeconds + " secs")
+ "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds " +
+ "and period of " + periodSeconds + " secs")
timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000)
}
diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala
index 33c317720c..0b74607fb8 100644
--- a/core/src/test/scala/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/spark/CheckpointSuite.scala
@@ -99,7 +99,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
// the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits.
// Note that this test is very specific to the current implementation of CartesianRDD.
val ones = sc.makeRDD(1 to 100, 10).map(x => x)
- ones.checkpoint // checkpoint that MappedRDD
+ ones.checkpoint() // checkpoint that MappedRDD
val cartesian = new CartesianRDD(sc, ones, ones)
val splitBeforeCheckpoint =
serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit])
@@ -125,7 +125,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
// the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits.
// Note that this test is very specific to the current implementation of CoalescedRDDSplits
val ones = sc.makeRDD(1 to 100, 10).map(x => x)
- ones.checkpoint // checkpoint that MappedRDD
+ ones.checkpoint() // checkpoint that MappedRDD
val coalesced = new CoalescedRDD(ones, 2)
val splitBeforeCheckpoint =
serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit])
@@ -160,7 +160,6 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
// so only the RDD will reduce in serialized size, not the splits.
testParentCheckpointing(
rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false)
-
}
/**
@@ -176,7 +175,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
testRDDSplitSize: Boolean = false
) {
// Generate the final RDD using given RDD operation
- val baseRDD = generateLongLineageRDD
+ val baseRDD = generateLongLineageRDD()
val operatedRDD = op(baseRDD)
val parentRDD = operatedRDD.dependencies.headOption.orNull
val rddType = operatedRDD.getClass.getSimpleName
@@ -245,12 +244,16 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
testRDDSplitSize: Boolean
) {
// Generate the final RDD using given RDD operation
- val baseRDD = generateLongLineageRDD
+ val baseRDD = generateLongLineageRDD()
val operatedRDD = op(baseRDD)
val parentRDD = operatedRDD.dependencies.head.rdd
val rddType = operatedRDD.getClass.getSimpleName
val parentRDDType = parentRDD.getClass.getSimpleName
+ // Get the splits and dependencies of the parent in case they're lazily computed
+ parentRDD.dependencies
+ parentRDD.splits
+
// Find serialized sizes before and after the checkpoint
val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
parentRDD.checkpoint() // checkpoint the parent RDD, not the generated one
@@ -267,7 +270,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
if (testRDDSize) {
assert(
rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
- "Size of " + rddType + " did not reduce after parent checkpointing parent " + parentRDDType +
+ "Size of " + rddType + " did not reduce after checkpointing parent " + parentRDDType +
"[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
)
}
@@ -318,10 +321,12 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
}
/**
- * Get serialized sizes of the RDD and its splits
+ * Get serialized sizes of the RDD and its splits, in order to test whether the size shrinks
+ * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint.
*/
def getSerializedSizes(rdd: RDD[_]): (Int, Int) = {
- (Utils.serialize(rdd).size, Utils.serialize(rdd.splits).size)
+ (Utils.serialize(rdd).length - Utils.serialize(rdd.checkpointData).length,
+ Utils.serialize(rdd.splits).length)
}
/**