aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/spark/rdd/CartesianRDD.scala
blob: 4a7e5f3d0602887a4b9f26dc0db170a9f038d40e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
package spark.rdd

import spark.{NarrowDependency, RDD, SparkContext, Split, TaskContext}


/**
 * One split of a [[CartesianRDD]]: a pairing of one split from each parent RDD.
 *
 * @param idx position of this split within the cartesian RDD's split array
 * @param s1  the split taken from the first (left-hand) parent
 * @param s2  the split taken from the second (right-hand) parent
 */
private[spark] class CartesianSplit(
    idx: Int,
    val s1: Split,
    val s2: Split)
  extends Split
  with Serializable {

  /** Index of this split, exactly as supplied by the constructor's `idx`. */
  override val index: Int = idx
}

/**
 * RDD representing the cartesian product of `rdd1` and `rdd2`: each element of `rdd1` paired
 * with each element of `rdd2`.
 *
 * Split number `i * numSplitsInRdd2 + j` pairs split `i` of `rdd1` with split `j` of `rdd2`, so
 * the result has `rdd1.splits.size * rdd2.splits.size` splits.
 *
 * @param sc   the SparkContext this RDD is created in
 * @param rdd1 left-hand parent; its elements appear first in each pair
 * @param rdd2 right-hand parent; its elements appear second in each pair
 */
private[spark]
class CartesianRDD[T: ClassManifest, U:ClassManifest](
    sc: SparkContext,
    rdd1: RDD[T],
    rdd2: RDD[U])
  extends RDD[Pair[T, U]](sc)
  with Serializable {

  /** Number of splits in `rdd2`; the stride used to encode an (s1, s2) pair as one index. */
  val numSplitsInRdd2 = rdd2.splits.size

  // Computed once at construction. @transient keeps the array of parent splits out of the
  // serialized form of this RDD.
  @transient
  val splits_ = {
    // Read each parent's split array once instead of re-reading rdd2.splits per outer iteration.
    val leftSplits = rdd1.splits
    val rightSplits = rdd2.splits
    // NOTE(review): assumes each parent's split indices run 0 until splits.size, so every
    // encoded index lands inside the array exactly once.
    val array = new Array[Split](leftSplits.size * numSplitsInRdd2)
    for (s1 <- leftSplits; s2 <- rightSplits) {
      val idx = s1.index * numSplitsInRdd2 + s2.index
      array(idx) = new CartesianSplit(idx, s1, s2)
    }
    array
  }

  override def splits = splits_

  /**
   * Locations preferred by either parent split, de-duplicated so that a host holding both
   * parent splits is listed once rather than twice.
   */
  override def preferredLocations(split: Split) = {
    val currSplit = split.asInstanceOf[CartesianSplit]
    (rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)).distinct
  }

  /**
   * Yields every (x, y) pair for this split. Because the inner generator sits inside the outer
   * one, `rdd2.iterator` is invoked afresh for each element produced by `rdd1`'s split.
   */
  override def compute(split: Split, context: TaskContext) = {
    val currSplit = split.asInstanceOf[CartesianSplit]
    for (x <- rdd1.iterator(currSplit.s1, context);
      y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
  }

  // Each output split depends on exactly one split of each parent, recovered by inverting the
  // index encoding used in splits_ (quotient -> rdd1 split, remainder -> rdd2 split).
  override val dependencies = List(
    new NarrowDependency(rdd1) {
      def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2)
    },
    new NarrowDependency(rdd2) {
      def getParents(id: Int): Seq[Int] = List(id % numSplitsInRdd2)
    }
  )
}