author    Matei Zaharia <matei@eecs.berkeley.edu>  2012-11-27 22:27:47 -0800
committer Matei Zaharia <matei@eecs.berkeley.edu>  2012-11-27 22:27:47 -0800
commit    27e43abd192440de5b10a5cc022fd5705362b276 (patch)
tree      fe3ac7899babc6f5c26f0dc90607820c9804e511
parent    59c0a9ad164ef8a6382737aa197f41e407e1c89d (diff)
Added a zip() operation for RDDs with the same shape (number of
partitions and number of elements in each partition)
-rw-r--r--  core/src/main/scala/spark/RDD.scala             9
-rw-r--r--  core/src/main/scala/spark/rdd/ZippedRDD.scala   54
-rw-r--r--  core/src/test/scala/spark/RDDSuite.scala        12
3 files changed, 75 insertions, 0 deletions
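
As a usage illustration (not part of the patch), here is a minimal sketch of the new zip() operation, assuming a local SparkContext and the makeRDD helper used in the test suite below:

    import spark.SparkContext

    object ZipExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "zip example")
        // The second RDD is produced by a map on the first, so both have the
        // same number of partitions and the same number of elements per partition.
        val nums    = sc.makeRDD(Array(1, 2, 3, 4), 2)
        val doubled = nums.map(_ * 2)
        // zip() pairs up corresponding elements: (1,2), (2,4), (3,6), (4,8)
        println(nums.zip(doubled).collect().toList)
        sc.stop()
      }
    }
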
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 338dff4061..f4288a9661 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -42,6 +42,7 @@ import spark.rdd.MapPartitionsWithSplitRDD
import spark.rdd.PipedRDD
import spark.rdd.SampledRDD
import spark.rdd.UnionRDD
+import spark.rdd.ZippedRDD
import spark.storage.StorageLevel
import SparkContext._
@@ -293,6 +294,14 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
def mapPartitionsWithSplit[U: ClassManifest](f: (Int, Iterator[T]) => Iterator[U]): RDD[U] =
new MapPartitionsWithSplitRDD(this, sc.clean(f))
+ /**
+ * Zips this RDD with another one, returning key-value pairs with the first element in each RDD,
+ * second element in each RDD, etc. Assumes that the two RDDs have the *same number of
+ * partitions* and the *same number of elements in each partition* (e.g. one was made through
+ * a map on the other).
+ */
+ def zip[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other)
+
// Actions (launch a job to return a value to the user program)
/**
diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala
new file mode 100644
index 0000000000..80f0150c45
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala
@@ -0,0 +1,54 @@
+package spark.rdd
+
+import spark.Dependency
+import spark.OneToOneDependency
+import spark.RDD
+import spark.SparkContext
+import spark.Split
+
+private[spark] class ZippedSplit[T: ClassManifest, U: ClassManifest](
+ idx: Int,
+ rdd1: RDD[T],
+ rdd2: RDD[U],
+ split1: Split,
+ split2: Split)
+ extends Split
+ with Serializable {
+
+ def iterator(): Iterator[(T, U)] = rdd1.iterator(split1).zip(rdd2.iterator(split2))
+
+ def preferredLocations(): Seq[String] =
+ rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2))
+
+ override val index: Int = idx
+}
+
+class ZippedRDD[T: ClassManifest, U: ClassManifest](
+ sc: SparkContext,
+ @transient rdd1: RDD[T],
+ @transient rdd2: RDD[U])
+ extends RDD[(T, U)](sc)
+ with Serializable {
+
+ @transient
+ val splits_ : Array[Split] = {
+ if (rdd1.splits.size != rdd2.splits.size) {
+ throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
+ }
+ val array = new Array[Split](rdd1.splits.size)
+ for (i <- 0 until rdd1.splits.size) {
+ array(i) = new ZippedSplit(i, rdd1, rdd2, rdd1.splits(i), rdd2.splits(i))
+ }
+ array
+ }
+
+ override def splits = splits_
+
+ @transient
+ override val dependencies = List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))
+
+ override def compute(s: Split): Iterator[(T, U)] = s.asInstanceOf[ZippedSplit[T, U]].iterator()
+
+ override def preferredLocations(s: Split): Seq[String] =
+ s.asInstanceOf[ZippedSplit[T, U]].preferredLocations()
+}
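
For context on the per-partition semantics that ZippedSplit.iterator() relies on: it simply delegates to Scala's Iterator.zip, which pairs corresponding elements in order and stops as soon as either iterator is exhausted. A plain-Scala sketch (not Spark code) of that behaviour:

    val a = Iterator(1, 2, 3)
    val b = Iterator("x", "y", "z")
    println(a.zip(b).toList)  // List((1,x), (2,y), (3,z))

This is why the two RDDs must also have the same number of elements in each partition: a length mismatch within a partition is silently truncated rather than reported.
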
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index 37a0ff0947..b3c820ed94 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -114,4 +114,16 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
assert(coalesced4.glom().collect().map(_.toList).toList ===
(1 to 10).map(x => List(x)).toList)
}
+
+ test("zipped RDDs") {
+ sc = new SparkContext("local", "test")
+ val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+ val zipped = nums.zip(nums.map(_ + 1.0))
+ assert(zipped.glom().map(_.toList).collect().toList ===
+ List(List((1, 2.0), (2, 3.0)), List((3, 4.0), (4, 5.0))))
+
+ intercept[IllegalArgumentException] {
+ nums.zip(sc.parallelize(1 to 4, 1)).collect()
+ }
+ }
}