aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
Diffstat (limited to 'core')
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/RDD.scala7
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala18
2 files changed, 25 insertions, 0 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 9dad794414..be47172581 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -112,6 +112,9 @@ abstract class RDD[T: ClassTag](
/**
* Implemented by subclasses to return the set of partitions in this RDD. This method will only
* be called once, so it is safe to implement a time-consuming computation in it.
+ *
+ * The partitions in this array must satisfy the following property:
+ * `rdd.partitions.zipWithIndex.forall { case (partition, index) => partition.index == index }`
*/
protected def getPartitions: Array[Partition]
@@ -237,6 +240,10 @@ abstract class RDD[T: ClassTag](
checkpointRDD.map(_.partitions).getOrElse {
if (partitions_ == null) {
partitions_ = getPartitions
+ partitions_.zipWithIndex.foreach { case (partition, index) =>
+ require(partition.index == index,
+ s"partitions($index).index == ${partition.index}, but it should equal $index")
+ }
}
partitions_
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index ef2ed44500..80347b800a 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -914,6 +914,24 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
}
}
+ test("RDD.partitions() fails fast when partitions indices are incorrect (SPARK-13021)") {
+ class BadRDD[T: ClassTag](prev: RDD[T]) extends RDD[T](prev) {
+
+ override def compute(part: Partition, context: TaskContext): Iterator[T] = {
+ prev.compute(part, context)
+ }
+
+ override protected def getPartitions: Array[Partition] = {
+ prev.partitions.reverse // breaks contract, which is that `rdd.partitions(i).index == i`
+ }
+ }
+ val rdd = new BadRDD(sc.parallelize(1 to 100, 100))
+ val e = intercept[IllegalArgumentException] {
+ rdd.partitions
+ }
+ assert(e.getMessage.contains("partitions"))
+ }
+
test("nested RDDs are not supported (SPARK-5063)") {
val rdd: RDD[Int] = sc.parallelize(1 to 100)
val rdd2: RDD[Int] = sc.parallelize(1 to 100)