author    Matei Zaharia <matei@eecs.berkeley.edu>  2010-10-16 17:13:52 -0700
committer Matei Zaharia <matei@eecs.berkeley.edu>  2010-10-16 17:13:52 -0700
commit 0e2adecdabe4b1b8f5df21ec16cea182b72d5626 (patch)
tree   c893defed49a51aff9804f85d439b32c04187c0f
parent 166d9f91258d1e803ac44c9faf429a3920728ec6 (diff)
Simplified UnionRDD slightly and added a SparkContext.union method for efficiently unioning a large number of RDDs
-rw-r--r--  src/scala/spark/RDD.scala           | 40
-rw-r--r--  src/scala/spark/SparkContext.scala  | 10
2 files changed, 22 insertions(+), 28 deletions(-)
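
Why the new method matters: chaining pairwise union calls over N RDDs builds N-1 nested two-way UnionRDDs, whereas the new SparkContext.union wraps all inputs in a single flat UnionRDD. A hypothetical driver-side sketch (assuming an existing SparkContext named sc; the RDD contents are illustrative, not from the patch):

    // Suppose the driver has many small RDDs to combine.
    val parts: Seq[RDD[Int]] = (1 to 100).map(i => sc.parallelize(Seq(i)))

    // Before this patch: pairwise unions nest 99 two-way UnionRDDs.
    val nested: RDD[Int] = parts.reduceLeft(_ union _)

    // After this patch: one flat UnionRDD over all 100 inputs.
    val flat: RDD[Int] = sc.union(parts: _*)
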
diff --git a/src/scala/spark/RDD.scala b/src/scala/spark/RDD.scala
index 803c063865..20865d7d28 100644
--- a/src/scala/spark/RDD.scala
+++ b/src/scala/spark/RDD.scala
@@ -82,11 +82,10 @@ abstract class RDD[T: ClassManifest](
try { map(x => 1L).reduce(_+_) }
catch { case e: UnsupportedOperationException => 0L }
- def union(other: RDD[T]) = new UnionRDD(sc, this, other)
+ def union(other: RDD[T]) = new UnionRDD(sc, Array(this, other))
def cartesian[U: ClassManifest](other: RDD[U]) = new CartesianRDD(sc, this, other)
def ++(other: RDD[T]) = this.union(other)
-
}
@serializable
@@ -268,36 +267,27 @@ private object CachedRDD {
}
@serializable
-abstract class UnionSplit[T: ClassManifest] extends Split {
- def iterator(): Iterator[T]
- def preferredLocations(): Seq[String]
- def getId(): String
-}
-
-@serializable
-class UnionSplitImpl[T: ClassManifest](
- rdd: RDD[T], split: Split)
-extends UnionSplit[T] {
- override def iterator() = rdd.iterator(split)
- override def preferredLocations() = rdd.preferredLocations(split)
- override def getId() =
- "UnionSplitImpl(" + split.getId() + ")"
+class UnionSplit[T: ClassManifest](rdd: RDD[T], split: Split)
+extends Split {
+ def iterator() = rdd.iterator(split)
+ def preferredLocations() = rdd.preferredLocations(split)
+ override def getId() = "UnionSplit(" + split.getId() + ")"
}
@serializable
-class UnionRDD[T: ClassManifest](
- sc: SparkContext, rdd1: RDD[T], rdd2: RDD[T])
+class UnionRDD[T: ClassManifest](sc: SparkContext, rdds: Seq[RDD[T]])
extends RDD[T](sc) {
-
- @transient val splits_ : Array[UnionSplit[T]] = {
- val a1 = rdd1.splits.map(s => new UnionSplitImpl(rdd1, s))
- val a2 = rdd2.splits.map(s => new UnionSplitImpl(rdd2, s))
- (a1 ++ a2).toArray
+ @transient val splits_ : Array[Split] = {
+ val splits: Seq[Split] =
+ for (rdd <- rdds; split <- rdd.splits)
+ yield new UnionSplit(rdd, split)
+ splits.toArray
}
- override def splits = splits_.asInstanceOf[Array[Split]]
+ override def splits = splits_
- override def iterator(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator()
+ override def iterator(s: Split): Iterator[T] =
+ s.asInstanceOf[UnionSplit[T]].iterator()
override def preferredLocations(s: Split): Seq[String] =
s.asInstanceOf[UnionSplit[T]].preferredLocations()
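
To see the new shape in isolation, here is a self-contained Scala sketch of the split flattening, using hypothetical Mini* stand-ins rather than the real Spark classes: each wrapper split pairs a parent RDD with one of its splits and delegates iteration to it, and the for-comprehension produces one wrapper per (rdd, split) pair across all inputs, in order.

    // Hypothetical Mini* stand-ins for the real Spark classes.
    case class Part(index: Int)                       // a plain split id

    class MiniRDD[T](data: Seq[Seq[T]]) {             // one inner Seq per split
      def splits: Seq[Part] = data.indices.map(Part(_))
      def iterator(s: Part): Iterator[T] = data(s.index).iterator
    }

    // Pairs a parent RDD with one of its splits and delegates to it,
    // like the patched UnionSplit.
    class MiniUnionSplit[T](rdd: MiniRDD[T], split: Part) {
      def iterator(): Iterator[T] = rdd.iterator(split)
    }

    class MiniUnion[T](rdds: Seq[MiniRDD[T]]) {
      // The same flattening for-comprehension as the patch: one wrapper
      // split per (rdd, split) pair across all input RDDs.
      val splits: Seq[MiniUnionSplit[T]] =
        for (rdd <- rdds; split <- rdd.splits)
          yield new MiniUnionSplit(rdd, split)
    }

    // new MiniUnion(Seq(new MiniRDD(Seq(Seq(1, 2))), new MiniRDD(Seq(Seq(3)))))
    // exposes two splits, iterating over 1, 2 and over 3 respectively.
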
diff --git a/src/scala/spark/SparkContext.scala b/src/scala/spark/SparkContext.scala
index 69c3332bb0..953eac9eba 100644
--- a/src/scala/spark/SparkContext.scala
+++ b/src/scala/spark/SparkContext.scala
@@ -33,13 +33,17 @@ extends Logging {
// Methods for creating RDDs
- def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int) =
+ def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int): RDD[T] =
new ParallelArray[T](this, seq, numSlices)
- def parallelize[T: ClassManifest](seq: Seq[T]): ParallelArray[T] =
+ def parallelize[T: ClassManifest](seq: Seq[T]): RDD[T] =
parallelize(seq, scheduler.numCores)
- def textFile(path: String) = new HdfsTextFile(this, path)
+ def textFile(path: String): RDD[String] =
+ new HdfsTextFile(this, path)
+
+ def union[T: ClassManifest](rdds: RDD[T]*): RDD[T] =
+ new UnionRDD(this, rdds)
// Methods for creating shared variables
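
Finally, a hypothetical end-to-end driver sketch against the patched SparkContext API (the master string, application name, and input path are placeholders, and the two-argument SparkContext constructor is an assumption about this era of the codebase, not shown in the diff):

    // Placeholder master and app name; constructor form is assumed.
    val sc = new SparkContext("local", "UnionDemo")

    // textFile and parallelize now advertise plain RDD return types.
    val lines: RDD[String] = sc.textFile("hdfs://host:9000/path/to/input")
    val nums: RDD[Int] = sc.parallelize(1 to 1000, 8)

    // The new varargs union builds one flat UnionRDD whose splits are
    // the concatenation of the three inputs' splits.
    val all: RDD[Int] = sc.union(nums, nums.map(_ * 2), nums.map(_ * 3))
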