author    | Matei Zaharia <matei@eecs.berkeley.edu> | 2011-07-14 14:47:12 -0400
committer | Matei Zaharia <matei@eecs.berkeley.edu> | 2011-07-14 14:47:12 -0400
commit    | 8ea67307b9855074dfaa7695eac1c280dec05027 (patch)
tree      | f34768c6764a24409f8e47b6e80a7bcef7b4ab42
parent    | 3efd9e94d86ee63869a73f5fd716c743353135ca (diff)
parent    | e4c3402d2d174c50a756d0ac6b6f888363f8d7c9 (diff)
Merge branch 'master' into scala-2.9
-rw-r--r-- | core/src/main/scala/spark/Accumulators.scala | 32
-rw-r--r-- | core/src/main/scala/spark/DAGScheduler.scala | 2
-rw-r--r-- | core/src/main/scala/spark/ParallelCollection.scala (renamed from core/src/main/scala/spark/ParallelArray.scala) | 23
-rw-r--r-- | core/src/main/scala/spark/RDD.scala | 4
-rw-r--r-- | core/src/main/scala/spark/SparkContext.scala | 2
-rw-r--r-- | core/src/test/scala/spark/ParallelCollectionSplitSuite.scala (renamed from core/src/test/scala/spark/ParallelArraySplitSuite.scala) | 32
6 files changed, 48 insertions, 47 deletions
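The main behavioral change in this merge is in Accumulators (the first file below): instead of keying every accumulator on a (Thread, id) pair, driver-side originals are now keyed by ID alone, and only the deserialized task-side copies are tracked per thread. That is what lets DAGScheduler call Accumulators.add without a thread argument. A minimal, self-contained Scala sketch of that bookkeeping pattern; this is hypothetical demo code using plain Long counters, not the commit's generic Accumulator classes, and names like registerOriginal, addLocal, and drainValues are invented for illustration:

import scala.collection.mutable

// Hypothetical demo: originals are keyed by accumulator ID on the master,
// while deserialized task-side copies are tracked per worker thread.
object AccumulatorBookkeepingDemo {
  private val originals = mutable.Map[Long, Long]()
  private val localAccums = mutable.Map[Thread, mutable.Map[Long, Long]]()

  // Master side: corresponds to register(a, original = true) in the commit.
  def registerOriginal(id: Long): Unit = synchronized { originals(id) = 0L }

  // Worker side: register(a, original = false) happens in readObject in the
  // commit; here registration is folded into the first update for brevity.
  def addLocal(id: Long, v: Long): Unit = synchronized {
    val accums = localAccums.getOrElseUpdate(Thread.currentThread, mutable.Map())
    accums(id) = accums.getOrElse(id, 0L) + v
  }

  // Worker side: like Accumulators.values followed by clear, collect this
  // thread's updates to ship back and drop the thread's entry.
  def drainValues(): Map[Long, Long] = synchronized {
    localAccums.remove(Thread.currentThread).map(_.toMap).getOrElse(Map())
  }

  // Master side: like the new Accumulators.add(values); no thread key needed.
  def add(values: Map[Long, Long]): Unit = synchronized {
    for ((id, v) <- values if originals.contains(id))
      originals(id) += v
  }

  def main(args: Array[String]): Unit = {
    registerOriginal(1L)
    var shipped = Map[Long, Long]()
    val worker = new Thread(new Runnable {
      def run(): Unit = { addLocal(1L, 40L); addLocal(1L, 2L); shipped = drainValues() }
    })
    worker.start()
    worker.join()  // join gives a happens-before edge, so reading shipped is safe
    add(shipped)   // as DAGScheduler does with evt.accumUpdates
    println(originals(1L))  // prints 42
  }
}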
diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala
index ee93d3c85c..4f51826d9d 100644
--- a/core/src/main/scala/spark/Accumulators.scala
+++ b/core/src/main/scala/spark/Accumulators.scala
@@ -12,7 +12,7 @@ import scala.collection.mutable.Map
   val zero = param.zero(initialValue) // Zero value to be passed to workers
   var deserialized = false
 
-  Accumulators.register(this)
+  Accumulators.register(this, true)
 
   def += (term: T) { value_ = param.addInPlace(value_, term) }
   def value = this.value_
@@ -26,7 +26,7 @@ import scala.collection.mutable.Map
     in.defaultReadObject
     value_ = zero
     deserialized = true
-    Accumulators.register(this)
+    Accumulators.register(this, false)
   }
 
   override def toString = value_.toString
@@ -42,31 +42,39 @@ import scala.collection.mutable.Map
 private object Accumulators {
   // TODO: Use soft references? => need to make readObject work properly then
-  val accums = Map[(Thread, Long), Accumulator[_]]()
-  var lastId: Long = 0
+  val originals = Map[Long, Accumulator[_]]()
+  val localAccums = Map[Thread, Map[Long, Accumulator[_]]]()
+  var lastId: Long = 0
 
   def newId: Long = synchronized { lastId += 1; return lastId }
 
-  def register(a: Accumulator[_]): Unit = synchronized {
-    accums((currentThread, a.id)) = a
+  def register(a: Accumulator[_], original: Boolean): Unit = synchronized {
+    if (original) {
+      originals(a.id) = a
+    } else {
+      val accums = localAccums.getOrElseUpdate(currentThread, Map())
+      accums(a.id) = a
+    }
   }
 
+  // Clear the local (non-original) accumulators for the current thread
   def clear: Unit = synchronized {
-    accums.retain((key, accum) => key._1 != currentThread)
+    localAccums.remove(currentThread)
   }
 
+  // Get the values of the local accumulators for the current thread (by ID)
   def values: Map[Long, Any] = synchronized {
     val ret = Map[Long, Any]()
-    for(((thread, id), accum) <- accums if thread == currentThread)
+    for ((id, accum) <- localAccums.getOrElse(currentThread, Map()))
       ret(id) = accum.value
     return ret
   }
 
-  def add(thread: Thread, values: Map[Long, Any]): Unit = synchronized {
+  // Add values to the original accumulators with some given IDs
+  def add(values: Map[Long, Any]): Unit = synchronized {
     for ((id, value) <- values) {
-      if (accums.contains((thread, id))) {
-        val accum = accums((thread, id))
-        accum.asInstanceOf[Accumulator[Any]] += value
+      if (originals.contains(id)) {
+        originals(id).asInstanceOf[Accumulator[Any]] += value
       }
     }
   }
diff --git a/core/src/main/scala/spark/DAGScheduler.scala b/core/src/main/scala/spark/DAGScheduler.scala
index 42bb3c2a75..93cab9fb62 100644
--- a/core/src/main/scala/spark/DAGScheduler.scala
+++ b/core/src/main/scala/spark/DAGScheduler.scala
@@ -225,7 +225,7 @@ private trait DAGScheduler extends Scheduler with Logging {
         if (evt.reason == Success) {
           // A task ended
           logInfo("Completed " + evt.task)
-          Accumulators.add(currentThread, evt.accumUpdates)
+          Accumulators.add(evt.accumUpdates)
           evt.task match {
             case rt: ResultTask[_, _] =>
               results(rt.outputId) = evt.result.asInstanceOf[U]
diff --git a/core/src/main/scala/spark/ParallelArray.scala b/core/src/main/scala/spark/ParallelCollection.scala
index e77bc3014f..36121766f5 100644
--- a/core/src/main/scala/spark/ParallelArray.scala
+++ b/core/src/main/scala/spark/ParallelCollection.scala
@@ -4,23 +4,23 @@ import mesos.SlaveOffer
 
 import java.util.concurrent.atomic.AtomicLong
 
-@serializable class ParallelArraySplit[T: ClassManifest](
-    val arrayId: Long, val slice: Int, values: Seq[T])
+@serializable class ParallelCollectionSplit[T: ClassManifest](
+    val rddId: Long, val slice: Int, values: Seq[T])
 extends Split {
   def iterator(): Iterator[T] = values.iterator
 
-  override def hashCode(): Int = (41 * (41 + arrayId) + slice).toInt
+  override def hashCode(): Int = (41 * (41 + rddId) + slice).toInt
 
   override def equals(other: Any): Boolean = other match {
-    case that: ParallelArraySplit[_] =>
-      (this.arrayId == that.arrayId && this.slice == that.slice)
+    case that: ParallelCollectionSplit[_] =>
+      (this.rddId == that.rddId && this.slice == that.slice)
     case _ => false
   }
 
   override val index = slice
 }
 
-class ParallelArray[T: ClassManifest](
+class ParallelCollection[T: ClassManifest](
     sc: SparkContext, @transient data: Seq[T], numSlices: Int)
 extends RDD[T](sc) {
   // TODO: Right now, each split sends along its full data, even if later down
@@ -28,23 +28,20 @@ extends RDD[T](sc) {
   // a file in the DFS and read it in the split instead.
   @transient
   val splits_ = {
-    val slices = ParallelArray.slice(data, numSlices).toArray
-    slices.indices.map(i => new ParallelArraySplit(id, i, slices(i))).toArray
+    val slices = ParallelCollection.slice(data, numSlices).toArray
+    slices.indices.map(i => new ParallelCollectionSplit(id, i, slices(i))).toArray
   }
 
   override def splits = splits_.asInstanceOf[Array[Split]]
 
-  override def compute(s: Split) = s.asInstanceOf[ParallelArraySplit[T]].iterator
+  override def compute(s: Split) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator
 
   override def preferredLocations(s: Split): Seq[String] = Nil
 
   override val dependencies: List[Dependency[_]] = Nil
 }
 
-private object ParallelArray {
-  val nextId = new AtomicLong(0) // Creates IDs for ParallelArrays (on master)
-  def newId() = nextId.getAndIncrement()
-
+private object ParallelCollection {
   def slice[T: ClassManifest](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
     if (numSlices < 1)
       throw new IllegalArgumentException("Positive number of slices required")
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index b8340d3f11..a0c4e29771 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -161,10 +161,6 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) {
 
   def toArray(): Array[T] = collect()
 
-  override def toString(): String = {
-    "%s(%d)".format(getClass.getSimpleName, id)
-  }
-
   // Take the first num elements of the RDD. This currently scans the partitions
   // *one by one*, so it will be slow if a lot of partitions are required. In that
   // case, use collect() to get the whole RDD instead.
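The rename above keeps the split-identity contract intact: a split is identified by its owning RDD's ID plus its slice index, with equality and hashing derived from both. Value-based equality matters because splits are serialized to worker processes, where reference equality would not survive the round trip. A small self-contained sketch; SplitIdentity is a hypothetical stand-in class, though the hashCode formula is the one from the diff:

// Hypothetical stand-in for ParallelCollectionSplit's identity methods.
class SplitIdentity(val rddId: Long, val slice: Int) {
  override def hashCode(): Int = (41 * (41 + rddId) + slice).toInt
  override def equals(other: Any): Boolean = other match {
    case that: SplitIdentity => this.rddId == that.rddId && this.slice == that.slice
    case _ => false
  }
}

object SplitIdentityDemo {
  def main(args: Array[String]): Unit = {
    val a = new SplitIdentity(7L, 2)
    val b = new SplitIdentity(7L, 2)  // a fresh instance, as after deserialization
    val c = new SplitIdentity(8L, 2)
    println(a == b && a.hashCode == b.hashCode)  // true: same RDD ID and slice
    println(a == c)                              // false: different RDD ID
  }
}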
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 0a866ce198..1f2bddc60e 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -64,7 +64,7 @@ extends Logging {
   // Methods for creating RDDs
 
   def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] =
-    new ParallelArray[T](this, seq, numSlices)
+    new ParallelCollection[T](this, seq, numSlices)
 
   def makeRDD[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] =
     parallelize(seq, numSlices)
diff --git a/core/src/test/scala/spark/ParallelArraySplitSuite.scala b/core/src/test/scala/spark/ParallelCollectionSplitSuite.scala
index 222df4e071..af6ec8bae5 100644
--- a/core/src/test/scala/spark/ParallelArraySplitSuite.scala
+++ b/core/src/test/scala/spark/ParallelCollectionSplitSuite.scala
@@ -6,10 +6,10 @@ import org.scalacheck.Arbitrary._
 import org.scalacheck.Gen
 import org.scalacheck.Prop._
 
-class ParallelArraySplitSuite extends FunSuite with Checkers {
+class ParallelCollectionSplitSuite extends FunSuite with Checkers {
   test("one element per slice") {
     val data = Array(1, 2, 3)
-    val slices = ParallelArray.slice(data, 3)
+    val slices = ParallelCollection.slice(data, 3)
     assert(slices.size === 3)
     assert(slices(0).mkString(",") === "1")
     assert(slices(1).mkString(",") === "2")
@@ -18,14 +18,14 @@ class ParallelArraySplitSuite extends FunSuite with Checkers {
 
   test("one slice") {
     val data = Array(1, 2, 3)
-    val slices = ParallelArray.slice(data, 1)
+    val slices = ParallelCollection.slice(data, 1)
     assert(slices.size === 1)
     assert(slices(0).mkString(",") === "1,2,3")
   }
 
   test("equal slices") {
     val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9)
-    val slices = ParallelArray.slice(data, 3)
+    val slices = ParallelCollection.slice(data, 3)
     assert(slices.size === 3)
     assert(slices(0).mkString(",") === "1,2,3")
     assert(slices(1).mkString(",") === "4,5,6")
@@ -34,7 +34,7 @@ class ParallelArraySplitSuite extends FunSuite with Checkers {
 
   test("non-equal slices") {
     val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
-    val slices = ParallelArray.slice(data, 3)
+    val slices = ParallelCollection.slice(data, 3)
     assert(slices.size === 3)
     assert(slices(0).mkString(",") === "1,2,3")
     assert(slices(1).mkString(",") === "4,5,6")
@@ -43,7 +43,7 @@ class ParallelArraySplitSuite extends FunSuite with Checkers {
 
   test("splitting exclusive range") {
     val data = 0 until 100
-    val slices = ParallelArray.slice(data, 3)
+    val slices = ParallelCollection.slice(data, 3)
     assert(slices.size === 3)
     assert(slices(0).mkString(",") === (0 to 32).mkString(","))
     assert(slices(1).mkString(",") === (33 to 65).mkString(","))
@@ -52,7 +52,7 @@ class ParallelArraySplitSuite extends FunSuite with Checkers {
 
   test("splitting inclusive range") {
     val data = 0 to 100
-    val slices = ParallelArray.slice(data, 3)
+    val slices = ParallelCollection.slice(data, 3)
     assert(slices.size === 3)
     assert(slices(0).mkString(",") === (0 to 32).mkString(","))
     assert(slices(1).mkString(",") === (33 to 66).mkString(","))
@@ -61,24 +61,24 @@ class ParallelArraySplitSuite extends FunSuite with Checkers {
 
   test("empty data") {
     val data = new Array[Int](0)
-    val slices = ParallelArray.slice(data, 5)
+    val slices = ParallelCollection.slice(data, 5)
     assert(slices.size === 5)
     for (slice <- slices) assert(slice.size === 0)
   }
 
   test("zero slices") {
     val data = Array(1, 2, 3)
-    intercept[IllegalArgumentException] { ParallelArray.slice(data, 0) }
+    intercept[IllegalArgumentException] { ParallelCollection.slice(data, 0) }
   }
 
   test("negative number of slices") {
     val data = Array(1, 2, 3)
-    intercept[IllegalArgumentException] { ParallelArray.slice(data, -5) }
+    intercept[IllegalArgumentException] { ParallelCollection.slice(data, -5) }
   }
 
   test("exclusive ranges sliced into ranges") {
     val data = 1 until 100
-    val slices = ParallelArray.slice(data, 3)
+    val slices = ParallelCollection.slice(data, 3)
     assert(slices.size === 3)
     assert(slices.map(_.size).reduceLeft(_+_) === 99)
     assert(slices.forall(_.isInstanceOf[Range]))
@@ -86,7 +86,7 @@ class ParallelArraySplitSuite extends FunSuite with Checkers {
 
   test("inclusive ranges sliced into ranges") {
     val data = 1 to 100
-    val slices = ParallelArray.slice(data, 3)
+    val slices = ParallelCollection.slice(data, 3)
     assert(slices.size === 3)
     assert(slices.map(_.size).reduceLeft(_+_) === 100)
     assert(slices.forall(_.isInstanceOf[Range]))
@@ -95,7 +95,7 @@ class ParallelArraySplitSuite extends FunSuite with Checkers {
   test("large ranges don't overflow") {
     val N = 100 * 1000 * 1000
     val data = 0 until N
-    val slices = ParallelArray.slice(data, 40)
+    val slices = ParallelCollection.slice(data, 40)
     assert(slices.size === 40)
     for (i <- 0 until 40) {
       assert(slices(i).isInstanceOf[Range])
@@ -115,7 +115,7 @@ class ParallelArraySplitSuite extends FunSuite with Checkers {
       (tuple: (List[Int], Int)) =>
         val d = tuple._1
         val n = tuple._2
-        val slices = ParallelArray.slice(d, n)
+        val slices = ParallelCollection.slice(d, n)
         ("n slices" |: slices.size == n) &&
         ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) &&
         ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1))
@@ -132,7 +132,7 @@ class ParallelArraySplitSuite extends FunSuite with Checkers {
     } yield (a until b by step, n)
     val prop = forAll(gen) {
       case (d: Range, n: Int) =>
-        val slices = ParallelArray.slice(d, n)
+        val slices = ParallelCollection.slice(d, n)
         ("n slices" |: slices.size == n) &&
         ("all ranges" |: slices.forall(_.isInstanceOf[Range])) &&
         ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) &&
@@ -150,7 +150,7 @@ class ParallelArraySplitSuite extends FunSuite with Checkers {
     } yield (a to b by step, n)
     val prop = forAll(gen) {
       case (d: Range, n: Int) =>
-        val slices = ParallelArray.slice(d, n)
+        val slices = ParallelCollection.slice(d, n)
        ("n slices" |: slices.size == n) &&
        ("all ranges" |: slices.forall(_.isInstanceOf[Range])) &&
        ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) &&
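Taken together, the renamed test suite pins down the contract of ParallelCollection.slice: exactly numSlices slices whose sizes differ by at most one, which concatenate back to the input, reject non-positive slice counts, and whose boundary arithmetic must not overflow Int on large ranges. A minimal re-implementation consistent with the boundaries the tests assert; SliceDemo is a hypothetical sketch, not the commit's actual slice code (only its argument check is visible in the diff, and the real version also keeps Range inputs as Ranges):

// Sketch of the slicing contract checked by the tests above. Boundaries are
// computed in Long so i * length cannot overflow Int for large inputs.
object SliceDemo {
  def slice[T](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
    if (numSlices < 1)
      throw new IllegalArgumentException("Positive number of slices required")
    (0 until numSlices).map { i =>
      val start = (i.toLong * seq.length / numSlices).toInt
      val end = ((i + 1).toLong * seq.length / numSlices).toInt
      seq.slice(start, end)
    }
  }

  def main(args: Array[String]): Unit = {
    // Same boundaries the "splitting exclusive range" test asserts: 0..32, 33..65, 66..99
    println(slice(0 until 100, 3).map(s => s"${s.head}..${s.last}"))
    // Slicing empty data into 5 slices yields 5 empty slices, as in "empty data"
    println(slice(Seq[Int](), 5).map(_.size))
  }
}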