author    Josh Rosen <joshrosen@databricks.com>  2016-01-27 13:27:32 -0800
committer Yin Huai <yhuai@databricks.com>        2016-01-27 13:27:32 -0800
commit    32f741115bda5d7d7dbfcd9fe827ecbea7303ffa (patch)
tree      88500d064baca2afd53b07c62708c93b0087611a /core
parent    87abcf7df921a5937fdb2bae8bfb30bfabc4970a (diff)
[SPARK-13021][CORE] Fail fast when custom RDDs violate the `RDD.partitions` API contract
Spark's `Partition` and `RDD.partitions` APIs have a contract which requires custom implementations of `RDD.partitions` to ensure that for all `x`, `rdd.partitions(x).index == x`; in other words, the `index` reported by a partition must match its position in the partitions array. If a custom RDD implementation violates this contract, then Spark can become stuck in an infinite recomputation loop when recomputing a subset of an RDD's partitions, since the tasks that actually run will not correspond to the missing output partitions that triggered the recomputation.

Here's a link to a notebook which demonstrates this problem: https://rawgit.com/JoshRosen/e520fb9a64c1c97ec985/raw/5e8a5aa8d2a18910a1607f0aa4190104adda3424/Violating%2520RDD.partitions%2520contract.html

In order to guard against this infinite-loop behavior, this patch modifies Spark so that it fails fast and refuses to compute RDDs whose `partitions` violate the API contract.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #10932 from JoshRosen/SPARK-13021.
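For illustration, here is a minimal sketch of a custom RDD that honors the contract; `ContractPartition` and `ContractRDD` are hypothetical names introduced for this example, not part of the patch:

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.rdd.RDD

// Hypothetical Partition whose `index` records its position in the array.
class ContractPartition(override val index: Int) extends Partition

// Hypothetical one-to-one wrapper RDD that honors the contract:
// partitions(i).index == i for every i.
class ContractRDD[T: ClassTag](prev: RDD[T]) extends RDD[T](prev) {

  override def compute(split: Partition, context: TaskContext): Iterator[T] =
    prev.iterator(prev.partitions(split.index), context)

  // Building the array with tabulate makes the invariant hold by construction.
  override protected def getPartitions: Array[Partition] =
    Array.tabulate(prev.partitions.length)(i => new ContractPartition(i))
}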
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala      |  7
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala | 18
2 files changed, 25 insertions(+), 0 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 9dad794414..be47172581 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -112,6 +112,9 @@ abstract class RDD[T: ClassTag](
/**
* Implemented by subclasses to return the set of partitions in this RDD. This method will only
* be called once, so it is safe to implement a time-consuming computation in it.
+ *
+ * The partitions in this array must satisfy the following property:
+ * `rdd.partitions.zipWithIndex.forall { case (partition, index) => partition.index == index }`
*/
protected def getPartitions: Array[Partition]
@@ -237,6 +240,10 @@ abstract class RDD[T: ClassTag](
checkpointRDD.map(_.partitions).getOrElse {
if (partitions_ == null) {
partitions_ = getPartitions
+ partitions_.zipWithIndex.foreach { case (partition, index) =>
+ require(partition.index == index,
+ s"partitions($index).partition == ${partition.index}, but it should equal $index")
+ }
}
partitions_
}
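Because the new check sits inside the lazily cached `partitions_` computation, it runs at most once per RDD, the first time `rdd.partitions` is evaluated on the driver, before any job is scheduled. A driver-side sketch of what a violation now looks like, assuming an active `SparkContext` named `sc` and the `BadRDD` class defined in the test diff below:

// Sketch: with this patch, merely materializing the partitions of a
// contract-violating RDD throws, instead of launching doomed tasks.
val rdd = new BadRDD(sc.parallelize(1 to 100, 100))
try {
  rdd.partitions
} catch {
  case e: IllegalArgumentException =>
    // e.g. "requirement failed: partitions(0).partition == 99, but it should equal 0"
    println(e.getMessage)
}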
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index ef2ed44500..80347b800a 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -914,6 +914,24 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
}
}
+ test("RDD.partitions() fails fast when partitions indicies are incorrect (SPARK-13021)") {
+ class BadRDD[T: ClassTag](prev: RDD[T]) extends RDD[T](prev) {
+
+ override def compute(part: Partition, context: TaskContext): Iterator[T] = {
+ prev.compute(part, context)
+ }
+
+ override protected def getPartitions: Array[Partition] = {
+ prev.partitions.reverse // breaks contract, which is that `rdd.partitions(i).index == i`
+ }
+ }
+ val rdd = new BadRDD(sc.parallelize(1 to 100, 100))
+ val e = intercept[IllegalArgumentException] {
+ rdd.partitions
+ }
+ assert(e.getMessage.contains("partitions"))
+ }
+
test("nested RDDs are not supported (SPARK-5063)") {
val rdd: RDD[Int] = sc.parallelize(1 to 100)
val rdd2: RDD[Int] = sc.parallelize(1 to 100)
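To run just the new regression test locally, an invocation along these lines should work (a sketch; the exact sbt/ScalaTest selector syntax is assumed here, not taken from the patch):

build/sbt "core/testOnly org.apache.spark.rdd.RDDSuite -- -z SPARK-13021"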