aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/spark/UnionRDD.scala
blob: dadfd94eefdb5c9c931076152f5f817055c1f5d0 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
package spark

import scala.collection.mutable.ArrayBuffer

/**
 * A split that wraps one split of another RDD and carries its own index.
 * Instances are created by `UnionRDD`, one per parent split.
 *
 * @param idx   value exposed as this split's `index`
 * @param rdd   the RDD that `split` is handed to when iterating or looking up locations
 * @param split the wrapped split; every call is delegated to `rdd` with this argument
 * @tparam T    element type of `rdd`
 */
class UnionSplit[T: ClassManifest](idx: Int, rdd: RDD[T], split: Split)
  extends Split
  with Serializable {

  /** Index of this split, exactly as passed to the constructor. */
  override val index: Int = idx

  /** Returns whatever `rdd.iterator` yields for the wrapped split. */
  def iterator() = {
    rdd.iterator(split)
  }

  /** Returns whatever `rdd.preferredLocations` reports for the wrapped split. */
  def preferredLocations() = {
    rdd.preferredLocations(split)
  }
}

/**
 * An RDD whose splits are the splits of every RDD in `rdds`, taken in the
 * order the parents appear in `rdds` and, within each parent, in the order
 * of that parent's `splits`.
 *
 * @param sc   context passed through to the `RDD` superclass constructor
 * @param rdds parent RDDs whose splits are concatenated
 * @tparam T   element type shared by this RDD and all of its parents
 */
class UnionRDD[T: ClassManifest](sc: SparkContext, rdds: Seq[RDD[T]])
extends RDD[T](sc) with Serializable {
  // One UnionSplit per parent split, numbered consecutively from 0.
  // Declared @transient, so the array is skipped by Java serialization.
  @transient val splits_ : Array[Split] = {
    val array = new Array[Split](rdds.map(_.splits.size).sum)
    var pos = 0
    for (rdd <- rdds; split <- rdd.splits) {
      array(pos) = new UnionSplit(pos, rdd, split)
      pos += 1
    }
    array
  }

  /** The splits of this union: the array built at construction time. */
  override def splits: Array[Split] = splits_

  /**
   * One `RangeDependency` per parent, in parent order. `pos` is the running
   * count of splits contributed by the parents seen so far.
   */
  override val dependencies: List[Dependency[_]] = {
    val deps = new ArrayBuffer[Dependency[_]]
    var pos = 0
    for (rdd <- rdds) {
      val numSplits = rdd.splits.size
      // NOTE(review): arguments are presumably (parent, start in parent,
      // start in union, length) -- confirm against RangeDependency.
      deps += new RangeDependency(rdd, 0, pos, numSplits)
      pos += numSplits
    }
    deps.toList
  }

  /**
   * Iterates over the data of split `s` by delegating to the wrapped parent split.
   *
   * @param s a split of this RDD; it is cast to `UnionSplit[T]`, so any other
   *          `Split` implementation fails with a `ClassCastException`
   */
  override def compute(s: Split): Iterator[T] =
    s.asInstanceOf[UnionSplit[T]].iterator()

  /**
   * Preferred locations of split `s`, as reported for the wrapped parent split.
   *
   * @param s a split of this RDD; it is cast to `UnionSplit[T]`, so any other
   *          `Split` implementation fails with a `ClassCastException`
   */
  override def preferredLocations(s: Split): Seq[String] =
    s.asInstanceOf[UnionSplit[T]].preferredLocations()
}