diff options
author | Matei Zaharia <matei@eecs.berkeley.edu> | 2010-10-16 17:13:52 -0700 |
---|---|---|
committer | Matei Zaharia <matei@eecs.berkeley.edu> | 2010-10-16 17:13:52 -0700 |
commit | 0e2adecdabe4b1b8f5df21ec16cea182b72d5626 (patch) | |
tree | c893defed49a51aff9804f85d439b32c04187c0f | |
parent | 166d9f91258d1e803ac44c9faf429a3920728ec6 (diff) | |
download | spark-0e2adecdabe4b1b8f5df21ec16cea182b72d5626.tar.gz spark-0e2adecdabe4b1b8f5df21ec16cea182b72d5626.tar.bz2 spark-0e2adecdabe4b1b8f5df21ec16cea182b72d5626.zip |
Simplified UnionRDD slightly and added a SparkContext.union method for efficiently union-ing a large number of RDDs
-rw-r--r-- | src/scala/spark/RDD.scala | 40 | ||||
-rw-r--r-- | src/scala/spark/SparkContext.scala | 10 |
2 files changed, 22 insertions, 28 deletions
diff --git a/src/scala/spark/RDD.scala b/src/scala/spark/RDD.scala
index 803c063865..20865d7d28 100644
--- a/src/scala/spark/RDD.scala
+++ b/src/scala/spark/RDD.scala
@@ -82,11 +82,10 @@ abstract class RDD[T: ClassManifest](
     try { map(x => 1L).reduce(_+_) }
     catch { case e: UnsupportedOperationException => 0L }
 
-  def union(other: RDD[T]) = new UnionRDD(sc, this, other)
+  def union(other: RDD[T]) = new UnionRDD(sc, Array(this, other))
   def cartesian[U: ClassManifest](other: RDD[U]) = new CartesianRDD(sc, this, other)
 
   def ++(other: RDD[T]) = this.union(other)
-
 }
 
 @serializable
@@ -268,36 +267,27 @@ private object CachedRDD {
 }
 
 @serializable
-abstract class UnionSplit[T: ClassManifest] extends Split {
-  def iterator(): Iterator[T]
-  def preferredLocations(): Seq[String]
-  def getId(): String
-}
-
-@serializable
-class UnionSplitImpl[T: ClassManifest](
-  rdd: RDD[T], split: Split)
-extends UnionSplit[T] {
-  override def iterator() = rdd.iterator(split)
-  override def preferredLocations() = rdd.preferredLocations(split)
-  override def getId() =
-    "UnionSplitImpl(" + split.getId() + ")"
+class UnionSplit[T: ClassManifest](rdd: RDD[T], split: Split)
+extends Split {
+  def iterator() = rdd.iterator(split)
+  def preferredLocations() = rdd.preferredLocations(split)
+  override def getId() = "UnionSplit(" + split.getId() + ")"
 }
 
 @serializable
-class UnionRDD[T: ClassManifest](
-  sc: SparkContext, rdd1: RDD[T], rdd2: RDD[T])
+class UnionRDD[T: ClassManifest](sc: SparkContext, rdds: Seq[RDD[T]])
 extends RDD[T](sc) {
-
-  @transient val splits_ : Array[UnionSplit[T]] = {
-    val a1 = rdd1.splits.map(s => new UnionSplitImpl(rdd1, s))
-    val a2 = rdd2.splits.map(s => new UnionSplitImpl(rdd2, s))
-    (a1 ++ a2).toArray
+  @transient val splits_ : Array[Split] = {
+    val splits: Seq[Split] =
+      for (rdd <- rdds; split <- rdd.splits)
+        yield new UnionSplit(rdd, split)
+    splits.toArray
   }
 
-  override def splits = splits_.asInstanceOf[Array[Split]]
+  override def splits = splits_
 
-  override def iterator(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator()
+  override def iterator(s: Split): Iterator[T] =
+    s.asInstanceOf[UnionSplit[T]].iterator()
 
   override def preferredLocations(s: Split): Seq[String] =
     s.asInstanceOf[UnionSplit[T]].preferredLocations()
diff --git a/src/scala/spark/SparkContext.scala b/src/scala/spark/SparkContext.scala
index 69c3332bb0..953eac9eba 100644
--- a/src/scala/spark/SparkContext.scala
+++ b/src/scala/spark/SparkContext.scala
@@ -33,13 +33,17 @@ extends Logging {
 
   // Methods for creating RDDs
 
-  def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int) =
+  def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int): RDD[T] =
     new ParallelArray[T](this, seq, numSlices)
 
-  def parallelize[T: ClassManifest](seq: Seq[T]): ParallelArray[T] =
+  def parallelize[T: ClassManifest](seq: Seq[T]): RDD[T] =
     parallelize(seq, scheduler.numCores)
 
-  def textFile(path: String) = new HdfsTextFile(this, path)
+  def textFile(path: String): RDD[String] =
+    new HdfsTextFile(this, path)
+
+  def union[T: ClassManifest](rdds: RDD[T]*): RDD[T] =
+    new UnionRDD(this, rdds)
 
   // Methods for creating shared variables
 