author     Reynold Xin <rxin@apache.org>   2014-07-30 21:30:13 -0700
committer  Reynold Xin <rxin@apache.org>   2014-07-30 21:30:13 -0700
commit     894d48ffb8c91e347ab60c58de983e1aaf181188 (patch)
tree       f95d3a79b5e8fa3a92e7293ac2f22464ae1f8ebb /core
parent     e966284409f9355e1169960e73a2215617c8cb22 (diff)
[SPARK-2758] UnionRDD's UnionPartition should not reference parent RDDs
Author: Reynold Xin <rxin@apache.org>

Closes #1675 from rxin/unionrdd and squashes the following commits:

941d316 [Reynold Xin] Clear RDDs for checkpointing.
c9f05f2 [Reynold Xin] [SPARK-2758] UnionRDD's UnionPartition should not reference parent RDDs
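The pattern behind the fix is worth spelling out before the diff: every task serializes its target partition, so a Partition that holds a strong reference to its parent RDD drags the whole parent lineage (and anything those RDDs close over) into every task binary. The patch marks the parent reference @transient and re-resolves the cached parent split at serialization time, which only runs on the driver where the transient fields are still populated. A condensed sketch of that pattern, matching the diff below but with a hypothetical class name (SlimPartition):

    import java.io.{IOException, ObjectOutputStream}

    import org.apache.spark.Partition
    import org.apache.spark.rdd.RDD

    // Sketch of the pattern this patch applies (class name hypothetical).
    // The parent RDD is only needed on the driver, so it is @transient and
    // never travels with the serialized partition.
    class SlimPartition[T](
        idx: Int,
        @transient rdd: RDD[T],           // dropped during Java serialization
        parentPartitionIndex: Int)
      extends Partition {

      // Cached on the driver; re-resolved just before each task is
      // serialized so a parent checkpointed in the meantime is picked up.
      var parentPartition: Partition = rdd.partitions(parentPartitionIndex)

      override val index: Int = idx

      @throws(classOf[IOException])
      private def writeObject(oos: ObjectOutputStream): Unit = {
        // Runs on the driver, where the @transient fields still exist.
        parentPartition = rdd.partitions(parentPartitionIndex)
        oos.defaultWriteObject()
      }
    }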
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala  41
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala  12
2 files changed, 42 insertions, 11 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
index 21c6e07d69..197167ecad 100644
--- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
@@ -25,21 +25,32 @@ import scala.reflect.ClassTag
import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
import org.apache.spark.annotation.DeveloperApi
-private[spark] class UnionPartition[T: ClassTag](idx: Int, rdd: RDD[T], splitIndex: Int)
+/**
+ * Partition for UnionRDD.
+ *
+ * @param idx index of the partition
+ * @param rdd the parent RDD this partition refers to
+ * @param parentRddIndex index of the parent RDD this partition refers to
+ * @param parentRddPartitionIndex index of the partition within the parent RDD
+ * this partition refers to
+ */
+private[spark] class UnionPartition[T: ClassTag](
+ idx: Int,
+ @transient rdd: RDD[T],
+ val parentRddIndex: Int,
+ @transient parentRddPartitionIndex: Int)
extends Partition {
- var split: Partition = rdd.partitions(splitIndex)
-
- def iterator(context: TaskContext) = rdd.iterator(split, context)
+ var parentPartition: Partition = rdd.partitions(parentRddPartitionIndex)
- def preferredLocations() = rdd.preferredLocations(split)
+ def preferredLocations() = rdd.preferredLocations(parentPartition)
override val index: Int = idx
@throws(classOf[IOException])
private def writeObject(oos: ObjectOutputStream) {
// Update the reference to parent split at the time of task serialization
- split = rdd.partitions(splitIndex)
+ parentPartition = rdd.partitions(parentRddPartitionIndex)
oos.defaultWriteObject()
}
}
@@ -47,14 +58,14 @@ private[spark] class UnionPartition[T: ClassTag](idx: Int, rdd: RDD[T], splitInd
@DeveloperApi
class UnionRDD[T: ClassTag](
sc: SparkContext,
- @transient var rdds: Seq[RDD[T]])
+ var rdds: Seq[RDD[T]])
extends RDD[T](sc, Nil) { // Nil since we implement getDependencies
override def getPartitions: Array[Partition] = {
val array = new Array[Partition](rdds.map(_.partitions.size).sum)
var pos = 0
- for (rdd <- rdds; split <- rdd.partitions) {
- array(pos) = new UnionPartition(pos, rdd, split.index)
+ for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) {
+ array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index)
pos += 1
}
array
@@ -70,9 +81,17 @@ class UnionRDD[T: ClassTag](
deps
}
- override def compute(s: Partition, context: TaskContext): Iterator[T] =
- s.asInstanceOf[UnionPartition[T]].iterator(context)
+ override def compute(s: Partition, context: TaskContext): Iterator[T] = {
+ val part = s.asInstanceOf[UnionPartition[T]]
+ val parentRdd = dependencies(part.parentRddIndex).rdd.asInstanceOf[RDD[T]]
+ parentRdd.iterator(part.parentPartition, context)
+ }
override def getPreferredLocations(s: Partition): Seq[String] =
s.asInstanceOf[UnionPartition[T]].preferredLocations()
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ rdds = null
+ }
}
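Because the @transient rdd field is null by the time a deserialized UnionPartition reaches an executor, compute() now recovers the parent through the UnionRDD's own dependency list, which is built in the same order as rdds; the new clearDependencies override additionally nulls out rdds so a checkpointed UnionRDD releases its parents. A hypothetical driver-side snippet (UnionLookupDemo is not part of the patch) illustrating the index correspondence that compute() relies on:

    import org.apache.spark.{SparkConf, SparkContext}

    // Hypothetical demo: the i-th dependency of a UnionRDD wraps the i-th
    // parent, so UnionPartition.parentRddIndex can index into `dependencies`
    // on the executor once the @transient rdd field is gone.
    object UnionLookupDemo {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setAppName("demo").setMaster("local")
        val sc = new SparkContext(conf)
        val a = sc.parallelize(1 to 10, 2)
        val b = sc.parallelize(11 to 20, 3)

        val union = a.union(b)  // a UnionRDD[Int] with 2 + 3 = 5 partitions
        assert(union.partitions.length == 5)
        assert(union.dependencies(0).rdd == a)  // parentRddIndex 0 -> a
        assert(union.dependencies(1).rdd == b)  // parentRddIndex 1 -> b

        sc.stop()
      }
    }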
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 8966eedd80..ae6e525875 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -121,6 +121,18 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(union.partitioner === nums1.partitioner)
}
+ test("UnionRDD partition serialized size should be small") {
+ val largeVariable = new Array[Byte](1000 * 1000)
+ val rdd1 = sc.parallelize(1 to 10, 2).map(i => largeVariable.length)
+ val rdd2 = sc.parallelize(1 to 10, 3)
+
+ val ser = SparkEnv.get.closureSerializer.newInstance()
+ val union = rdd1.union(rdd2)
+ // The UnionRDD itself should be large, but each individual partition should be small.
+ assert(ser.serialize(union).limit() > 2000)
+ assert(ser.serialize(union.partitions.head).limit() < 2000)
+ }
+
test("aggregate") {
val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3)))
type StringMap = HashMap[String, Int]
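The new regression test above uses Spark's closure serializer, but the same sanity check can be approximated with plain Java serialization. A small framework-free helper (hypothetical, not part of the patch) for eyeballing how many bytes an object such as a partition would add to each task:

    import java.io.{ByteArrayOutputStream, ObjectOutputStream}

    // Hypothetical helper: measure an object's Java-serialized size in
    // bytes. A partition that accidentally captures its parent RDD (and
    // therefore the whole lineage plus any closed-over driver state)
    // shows up as an outsized result here.
    object SerializedSize {
      def of(obj: AnyRef): Int = {
        val bytes = new ByteArrayOutputStream()
        val out = new ObjectOutputStream(bytes)
        try out.writeObject(obj) finally out.close()
        bytes.size()
      }
    }

For example, SerializedSize.of(union.partitions.head) should now stay far below the roughly 1 MB that a captured largeVariable would otherwise add to every task.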