path: root/core/src/main/scala/spark/rdd/CartesianRDD.scala
Diffstat (limited to 'core/src/main/scala/spark/rdd/CartesianRDD.scala')
-rw-r--r--  core/src/main/scala/spark/rdd/CartesianRDD.scala  54
1 file changed, 54 insertions(+), 0 deletions(-)
diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala
new file mode 100644
index 0000000000..7c354b6b2e
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala
@@ -0,0 +1,54 @@
+package spark.rdd
+
+import spark.NarrowDependency
+import spark.RDD
+import spark.SparkContext
+import spark.Split
+
+private[spark]
+class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with Serializable {
+ override val index: Int = idx
+}
+
+private[spark]
+class CartesianRDD[T: ClassManifest, U: ClassManifest](
+ sc: SparkContext,
+ rdd1: RDD[T],
+ rdd2: RDD[U])
+ extends RDD[Pair[T, U]](sc)
+ with Serializable {
+
+ val numSplitsInRdd2 = rdd2.splits.size
+
+ @transient
+ val splits_ = {
+    // create the cross-product splits, one per (s1, s2) pair of parent splits
+ val array = new Array[Split](rdd1.splits.size * rdd2.splits.size)
+ for (s1 <- rdd1.splits; s2 <- rdd2.splits) {
+ val idx = s1.index * numSplitsInRdd2 + s2.index
+ array(idx) = new CartesianSplit(idx, s1, s2)
+ }
+ array
+ }
+
+ override def splits = splits_
+
+ override def preferredLocations(split: Split) = {
+ val currSplit = split.asInstanceOf[CartesianSplit]
+ rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)
+ }
+
+ override def compute(split: Split) = {
+ val currSplit = split.asInstanceOf[CartesianSplit]
+ for (x <- rdd1.iterator(currSplit.s1); y <- rdd2.iterator(currSplit.s2)) yield (x, y)
+ }
+
+ override val dependencies = List(
+ new NarrowDependency(rdd1) {
+ def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2)
+ },
+ new NarrowDependency(rdd2) {
+ def getParents(id: Int): Seq[Int] = List(id % numSplitsInRdd2)
+ }
+ )
+}
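
The split numbering above packs the pair (s1.index, s2.index) into a single integer, and the two NarrowDependency blocks invert that packing with division and modulo. A minimal, self-contained sketch of the arithmetic (plain Scala, no Spark dependency; the names encode and parents are illustrative, not from the patch):

```scala
object CartesianIndexing {
  // Pack a (rdd1 split, rdd2 split) pair into one child split id,
  // mirroring `s1.index * numSplitsInRdd2 + s2.index` above.
  def encode(i1: Int, i2: Int, numSplitsInRdd2: Int): Int =
    i1 * numSplitsInRdd2 + i2

  // Invert the packing, mirroring the two NarrowDependency.getParents
  // implementations: division recovers the rdd1 split, modulo the rdd2 split.
  def parents(id: Int, numSplitsInRdd2: Int): (Int, Int) =
    (id / numSplitsInRdd2, id % numSplitsInRdd2)

  def main(args: Array[String]): Unit = {
    val numSplitsInRdd2 = 4
    for (i1 <- 0 until 3; i2 <- 0 until numSplitsInRdd2) {
      val id = encode(i1, i2, numSplitsInRdd2)
      assert(parents(id, numSplitsInRdd2) == ((i1, i2)))
    }
    println("3 x 4 parent splits map to child ids 0..11 and back, losslessly")
  }
}
```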
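For context, this class backs RDD.cartesian in this era of the codebase. A hedged usage sketch, assuming a 2012-era spark.SparkContext built with a master string and a job name (the application code here is illustrative, not part of the patch):

```scala
import spark.SparkContext

object CartesianExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "CartesianExample")
    val letters = sc.parallelize(Seq("a", "b"), 2)  // 2 splits
    val numbers = sc.parallelize(Seq(1, 2, 3), 3)   // 3 splits
    // cartesian builds a CartesianRDD with 2 * 3 = 6 splits,
    // one per (letters split, numbers split) pair.
    val pairs = letters.cartesian(numbers)
    pairs.collect().foreach(println)  // (a,1), (a,2), (a,3), (b,1), ...
    sc.stop()
  }
}
```

Because each child split reads exactly one split from each parent, both dependencies are narrow: the cross product needs no shuffle, and preferredLocations can union the locality hints of the two parent splits.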