author    Matei Zaharia <matei@eecs.berkeley.edu>  2011-07-14 14:47:12 -0400
committer Matei Zaharia <matei@eecs.berkeley.edu>  2011-07-14 14:47:12 -0400
commit    8ea67307b9855074dfaa7695eac1c280dec05027 (patch)
tree      f34768c6764a24409f8e47b6e80a7bcef7b4ab42 /core
parent    3efd9e94d86ee63869a73f5fd716c743353135ca (diff)
parent    e4c3402d2d174c50a756d0ac6b6f888363f8d7c9 (diff)
Merge branch 'master' into scala-2.9
Diffstat (limited to 'core')
 core/src/main/scala/spark/Accumulators.scala                                                                                        | 32
 core/src/main/scala/spark/DAGScheduler.scala                                                                                        |  2
 core/src/main/scala/spark/ParallelCollection.scala (renamed from core/src/main/scala/spark/ParallelArray.scala)                     | 23
 core/src/main/scala/spark/RDD.scala                                                                                                 |  4
 core/src/main/scala/spark/SparkContext.scala                                                                                        |  2
 core/src/test/scala/spark/ParallelCollectionSplitSuite.scala (renamed from core/src/test/scala/spark/ParallelArraySplitSuite.scala) | 32
 6 files changed, 48 insertions(+), 47 deletions(-)
diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala
index ee93d3c85c..4f51826d9d 100644
--- a/core/src/main/scala/spark/Accumulators.scala
+++ b/core/src/main/scala/spark/Accumulators.scala
@@ -12,7 +12,7 @@ import scala.collection.mutable.Map
val zero = param.zero(initialValue) // Zero value to be passed to workers
var deserialized = false
- Accumulators.register(this)
+ Accumulators.register(this, true)
def += (term: T) { value_ = param.addInPlace(value_, term) }
def value = this.value_
@@ -26,7 +26,7 @@ import scala.collection.mutable.Map
in.defaultReadObject
value_ = zero
deserialized = true
- Accumulators.register(this)
+ Accumulators.register(this, false)
}
override def toString = value_.toString
@@ -42,31 +42,39 @@ import scala.collection.mutable.Map
private object Accumulators
{
// TODO: Use soft references? => need to make readObject work properly then
- val accums = Map[(Thread, Long), Accumulator[_]]()
- var lastId: Long = 0
+ val originals = Map[Long, Accumulator[_]]()
+ val localAccums = Map[Thread, Map[Long, Accumulator[_]]]()
+ var lastId: Long = 0
def newId: Long = synchronized { lastId += 1; return lastId }
- def register(a: Accumulator[_]): Unit = synchronized {
- accums((currentThread, a.id)) = a
+ def register(a: Accumulator[_], original: Boolean): Unit = synchronized {
+ if (original) {
+ originals(a.id) = a
+ } else {
+ val accums = localAccums.getOrElseUpdate(currentThread, Map())
+ accums(a.id) = a
+ }
}
+ // Clear the local (non-original) accumulators for the current thread
def clear: Unit = synchronized {
- accums.retain((key, accum) => key._1 != currentThread)
+ localAccums.remove(currentThread)
}
+ // Get the values of the local accumulators for the current thread (by ID)
def values: Map[Long, Any] = synchronized {
val ret = Map[Long, Any]()
- for(((thread, id), accum) <- accums if thread == currentThread)
+ for ((id, accum) <- localAccums.getOrElse(currentThread, Map()))
ret(id) = accum.value
return ret
}
- def add(thread: Thread, values: Map[Long, Any]): Unit = synchronized {
+ // Add values to the original accumulators with some given IDs
+ def add(values: Map[Long, Any]): Unit = synchronized {
for ((id, value) <- values) {
- if (accums.contains((thread, id))) {
- val accum = accums((thread, id))
- accum.asInstanceOf[Accumulator[Any]] += value
+ if (originals.contains(id)) {
+ originals(id).asInstanceOf[Accumulator[Any]] += value
}
}
}
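The change above splits registration into two maps: accumulators created on the driver are recorded once in originals, keyed by id, while copies deserialized inside a task are tracked per worker thread in localAccums. A minimal standalone sketch of that flow, simplified to plain Long sums (AccumulatorFlowSketch and its method names are illustrative, not part of the patch):

import scala.collection.mutable.Map

object AccumulatorFlowSketch {
  // Driver side: one entry per accumulator created in user code, keyed by id.
  val originals = Map[Long, Long]()
  // Worker side: per-thread sums for the copies deserialized into running tasks.
  val localAccums = Map[Thread, Map[Long, Long]]()

  def registerOriginal(id: Long): Unit = synchronized { originals(id) = 0L }

  def addLocal(id: Long, term: Long): Unit = synchronized {
    val accums = localAccums.getOrElseUpdate(Thread.currentThread, Map())
    accums(id) = accums.getOrElse(id, 0L) + term
  }

  // Snapshot and clear the current thread's updates, as a finishing task would.
  def drainLocal(): Map[Long, Long] = synchronized {
    localAccums.remove(Thread.currentThread).getOrElse(Map[Long, Long]())
  }

  // Master side: fold updates into the originals by id, on whatever thread
  // happens to process the task-completion event.
  def add(values: Map[Long, Long]): Unit = synchronized {
    for ((id, v) <- values if originals.contains(id)) originals(id) += v
  }
}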
diff --git a/core/src/main/scala/spark/DAGScheduler.scala b/core/src/main/scala/spark/DAGScheduler.scala
index 42bb3c2a75..93cab9fb62 100644
--- a/core/src/main/scala/spark/DAGScheduler.scala
+++ b/core/src/main/scala/spark/DAGScheduler.scala
@@ -225,7 +225,7 @@ private trait DAGScheduler extends Scheduler with Logging {
if (evt.reason == Success) {
// A task ended
logInfo("Completed " + evt.task)
- Accumulators.add(currentThread, evt.accumUpdates)
+ Accumulators.add(evt.accumUpdates)
evt.task match {
case rt: ResultTask[_, _] =>
results(rt.outputId) = evt.result.asInstanceOf[U]
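With updates keyed purely by accumulator id, the scheduler no longer needs to know which thread registered the original; merging an accumUpdates map back into the originals can happen on whichever thread handles the completion event. A hypothetical run-through using the sketch above:

AccumulatorFlowSketch.registerOriginal(1L)        // driver creates accumulator 1
AccumulatorFlowSketch.addLocal(1L, 5L)            // a task adds to its local copy
val updates = AccumulatorFlowSketch.drainLocal()  // collected when the task ends
AccumulatorFlowSketch.add(updates)                // was: Accumulators.add(currentThread, updates)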
diff --git a/core/src/main/scala/spark/ParallelArray.scala b/core/src/main/scala/spark/ParallelCollection.scala
index e77bc3014f..36121766f5 100644
--- a/core/src/main/scala/spark/ParallelArray.scala
+++ b/core/src/main/scala/spark/ParallelCollection.scala
@@ -4,23 +4,23 @@ import mesos.SlaveOffer
import java.util.concurrent.atomic.AtomicLong
-@serializable class ParallelArraySplit[T: ClassManifest](
- val arrayId: Long, val slice: Int, values: Seq[T])
+@serializable class ParallelCollectionSplit[T: ClassManifest](
+ val rddId: Long, val slice: Int, values: Seq[T])
extends Split {
def iterator(): Iterator[T] = values.iterator
- override def hashCode(): Int = (41 * (41 + arrayId) + slice).toInt
+ override def hashCode(): Int = (41 * (41 + rddId) + slice).toInt
override def equals(other: Any): Boolean = other match {
- case that: ParallelArraySplit[_] =>
- (this.arrayId == that.arrayId && this.slice == that.slice)
+ case that: ParallelCollectionSplit[_] =>
+ (this.rddId == that.rddId && this.slice == that.slice)
case _ => false
}
override val index = slice
}
-class ParallelArray[T: ClassManifest](
+class ParallelCollection[T: ClassManifest](
sc: SparkContext, @transient data: Seq[T], numSlices: Int)
extends RDD[T](sc) {
// TODO: Right now, each split sends along its full data, even if later down
@@ -28,23 +28,20 @@ extends RDD[T](sc) {
// a file in the DFS and read it in the split instead.
@transient val splits_ = {
- val slices = ParallelArray.slice(data, numSlices).toArray
- slices.indices.map(i => new ParallelArraySplit(id, i, slices(i))).toArray
+ val slices = ParallelCollection.slice(data, numSlices).toArray
+ slices.indices.map(i => new ParallelCollectionSplit(id, i, slices(i))).toArray
}
override def splits = splits_.asInstanceOf[Array[Split]]
- override def compute(s: Split) = s.asInstanceOf[ParallelArraySplit[T]].iterator
+ override def compute(s: Split) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator
override def preferredLocations(s: Split): Seq[String] = Nil
override val dependencies: List[Dependency[_]] = Nil
}
-private object ParallelArray {
- val nextId = new AtomicLong(0) // Creates IDs for ParallelArrays (on master)
- def newId() = nextId.getAndIncrement()
-
+private object ParallelCollection {
def slice[T: ClassManifest](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
if (numSlices < 1)
throw new IllegalArgumentException("Positive number of slices required")
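The rename also drops ParallelArray's own id counter, since ParallelCollectionSplit now carries the RDD's id. The diff cuts off before the slicing logic itself, but the renamed tests below pin down its observable behaviour; here is a minimal re-implementation that satisfies them (SliceSketch is an illustrative name, and the real method additionally special-cases Range inputs so that ranges slice into ranges):

object SliceSketch {
  // Split seq into numSlices contiguous groups whose sizes differ by at most one,
  // e.g. 0 until 100 into 3 slices -> sizes 33, 33, 34.
  def slice[T](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
    if (numSlices < 1)
      throw new IllegalArgumentException("Positive number of slices required")
    (0 until numSlices).map { i =>
      val start = (i * seq.length.toLong / numSlices).toInt
      val end = ((i + 1) * seq.length.toLong / numSlices).toInt
      seq.slice(start, end)
    }
  }
}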
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index b8340d3f11..a0c4e29771 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -161,10 +161,6 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) {
def toArray(): Array[T] = collect()
- override def toString(): String = {
- "%s(%d)".format(getClass.getSimpleName, id)
- }
-
// Take the first num elements of the RDD. This currently scans the partitions
// *one by one*, so it will be slow if a lot of partitions are required. In that
// case, use collect() to get the whole RDD instead.
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 0a866ce198..1f2bddc60e 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -64,7 +64,7 @@ extends Logging {
// Methods for creating RDDs
def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] =
- new ParallelArray[T](this, seq, numSlices)
+ new ParallelCollection[T](this, seq, numSlices)
def makeRDD[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] =
parallelize(seq, numSlices)
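User code is unaffected by the rename: parallelize and makeRDD keep their signatures and simply construct a ParallelCollection instead of a ParallelArray. A usage sketch (the SparkContext constructor arguments, a local master and a framework name, are assumed from the era's API rather than shown in this diff):

val sc = new SparkContext("local", "sliceExample")  // assumed (master, frameworkName) constructor
val rdd = sc.parallelize(1 to 100, numSlices = 4)   // now backed by a ParallelCollection
println(rdd.reduce(_ + _))                          // 5050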
diff --git a/core/src/test/scala/spark/ParallelArraySplitSuite.scala b/core/src/test/scala/spark/ParallelCollectionSplitSuite.scala
index 222df4e071..af6ec8bae5 100644
--- a/core/src/test/scala/spark/ParallelArraySplitSuite.scala
+++ b/core/src/test/scala/spark/ParallelCollectionSplitSuite.scala
@@ -6,10 +6,10 @@ import org.scalacheck.Arbitrary._
import org.scalacheck.Gen
import org.scalacheck.Prop._
-class ParallelArraySplitSuite extends FunSuite with Checkers {
+class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("one element per slice") {
val data = Array(1, 2, 3)
- val slices = ParallelArray.slice(data, 3)
+ val slices = ParallelCollection.slice(data, 3)
assert(slices.size === 3)
assert(slices(0).mkString(",") === "1")
assert(slices(1).mkString(",") === "2")
@@ -18,14 +18,14 @@ class ParallelArraySplitSuite extends FunSuite with Checkers {
test("one slice") {
val data = Array(1, 2, 3)
- val slices = ParallelArray.slice(data, 1)
+ val slices = ParallelCollection.slice(data, 1)
assert(slices.size === 1)
assert(slices(0).mkString(",") === "1,2,3")
}
test("equal slices") {
val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9)
- val slices = ParallelArray.slice(data, 3)
+ val slices = ParallelCollection.slice(data, 3)
assert(slices.size === 3)
assert(slices(0).mkString(",") === "1,2,3")
assert(slices(1).mkString(",") === "4,5,6")
@@ -34,7 +34,7 @@ class ParallelArraySplitSuite extends FunSuite with Checkers {
test("non-equal slices") {
val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
- val slices = ParallelArray.slice(data, 3)
+ val slices = ParallelCollection.slice(data, 3)
assert(slices.size === 3)
assert(slices(0).mkString(",") === "1,2,3")
assert(slices(1).mkString(",") === "4,5,6")
@@ -43,7 +43,7 @@ class ParallelArraySplitSuite extends FunSuite with Checkers {
test("splitting exclusive range") {
val data = 0 until 100
- val slices = ParallelArray.slice(data, 3)
+ val slices = ParallelCollection.slice(data, 3)
assert(slices.size === 3)
assert(slices(0).mkString(",") === (0 to 32).mkString(","))
assert(slices(1).mkString(",") === (33 to 65).mkString(","))
@@ -52,7 +52,7 @@ class ParallelArraySplitSuite extends FunSuite with Checkers {
test("splitting inclusive range") {
val data = 0 to 100
- val slices = ParallelArray.slice(data, 3)
+ val slices = ParallelCollection.slice(data, 3)
assert(slices.size === 3)
assert(slices(0).mkString(",") === (0 to 32).mkString(","))
assert(slices(1).mkString(",") === (33 to 66).mkString(","))
@@ -61,24 +61,24 @@ class ParallelArraySplitSuite extends FunSuite with Checkers {
test("empty data") {
val data = new Array[Int](0)
- val slices = ParallelArray.slice(data, 5)
+ val slices = ParallelCollection.slice(data, 5)
assert(slices.size === 5)
for (slice <- slices) assert(slice.size === 0)
}
test("zero slices") {
val data = Array(1, 2, 3)
- intercept[IllegalArgumentException] { ParallelArray.slice(data, 0) }
+ intercept[IllegalArgumentException] { ParallelCollection.slice(data, 0) }
}
test("negative number of slices") {
val data = Array(1, 2, 3)
- intercept[IllegalArgumentException] { ParallelArray.slice(data, -5) }
+ intercept[IllegalArgumentException] { ParallelCollection.slice(data, -5) }
}
test("exclusive ranges sliced into ranges") {
val data = 1 until 100
- val slices = ParallelArray.slice(data, 3)
+ val slices = ParallelCollection.slice(data, 3)
assert(slices.size === 3)
assert(slices.map(_.size).reduceLeft(_+_) === 99)
assert(slices.forall(_.isInstanceOf[Range]))
@@ -86,7 +86,7 @@ class ParallelArraySplitSuite extends FunSuite with Checkers {
test("inclusive ranges sliced into ranges") {
val data = 1 to 100
- val slices = ParallelArray.slice(data, 3)
+ val slices = ParallelCollection.slice(data, 3)
assert(slices.size === 3)
assert(slices.map(_.size).reduceLeft(_+_) === 100)
assert(slices.forall(_.isInstanceOf[Range]))
@@ -95,7 +95,7 @@ class ParallelArraySplitSuite extends FunSuite with Checkers {
test("large ranges don't overflow") {
val N = 100 * 1000 * 1000
val data = 0 until N
- val slices = ParallelArray.slice(data, 40)
+ val slices = ParallelCollection.slice(data, 40)
assert(slices.size === 40)
for (i <- 0 until 40) {
assert(slices(i).isInstanceOf[Range])
@@ -115,7 +115,7 @@ class ParallelArraySplitSuite extends FunSuite with Checkers {
(tuple: (List[Int], Int)) =>
val d = tuple._1
val n = tuple._2
- val slices = ParallelArray.slice(d, n)
+ val slices = ParallelCollection.slice(d, n)
("n slices" |: slices.size == n) &&
("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) &&
("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1))
@@ -132,7 +132,7 @@ class ParallelArraySplitSuite extends FunSuite with Checkers {
} yield (a until b by step, n)
val prop = forAll(gen) {
case (d: Range, n: Int) =>
- val slices = ParallelArray.slice(d, n)
+ val slices = ParallelCollection.slice(d, n)
("n slices" |: slices.size == n) &&
("all ranges" |: slices.forall(_.isInstanceOf[Range])) &&
("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) &&
@@ -150,7 +150,7 @@ class ParallelArraySplitSuite extends FunSuite with Checkers {
} yield (a to b by step, n)
val prop = forAll(gen) {
case (d: Range, n: Int) =>
- val slices = ParallelArray.slice(d, n)
+ val slices = ParallelCollection.slice(d, n)
("n slices" |: slices.size == n) &&
("all ranges" |: slices.forall(_.isInstanceOf[Range])) &&
("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) &&