author     Imran Rashid <imran@quantifind.com>        2012-07-26 12:38:51 -0700
committer  Matei Zaharia <matei@eecs.berkeley.edu>    2012-07-28 20:16:17 -0700
commit     f7149c5e46281498dad4644bdc468e63ad6da667 (patch)
tree       bffa0b99bbb24e9d2f20dfa10f5ae5f39b2697f7
parent     244cbbe33a3f1e2566cde322eef2a02a11d35096 (diff)
tasks cannot access value of accumulator
-rw-r--r--  core/src/main/scala/spark/Accumulators.scala      12
-rw-r--r--  core/src/test/scala/spark/AccumulatorSuite.scala   65
2 files changed, 21 insertions(+), 56 deletions(-)
diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala
index e63651fcb0..a155adaa87 100644
--- a/core/src/main/scala/spark/Accumulators.scala
+++ b/core/src/main/scala/spark/Accumulators.scala
@@ -11,7 +11,7 @@ class Accumulable[T,R] (
val id = Accumulators.newId
@transient
- var value_ = initialValue // Current value on master
+ private var value_ = initialValue // Current value on master
val zero = param.zero(initialValue) // Zero value to be passed to workers
var deserialized = false
@@ -30,7 +30,13 @@ class Accumulable[T,R] (
* @param term the other Accumulable that will get merged with this
*/
def ++= (term: T) { value_ = param.addInPlace(value_, term)}
- def value = this.value_
+ def value = {
+ if (!deserialized) value_
+ else throw new UnsupportedOperationException("Can't read accumulator value in task")
+ }
+
+ private[spark] def localValue = value_
+
def value_= (t: T) {
if (!deserialized) value_ = t
else throw new UnsupportedOperationException("Can't use value_= in task")
@@ -126,7 +132,7 @@ private object Accumulators {
def values: Map[Long, Any] = synchronized {
val ret = Map[Long, Any]()
for ((id, accum) <- localAccums.getOrElse(Thread.currentThread, Map())) {
- ret(id) = accum.value
+ ret(id) = accum.localValue
}
return ret
}
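
To make the semantics of this hunk concrete, here is a minimal sketch of how it behaves from user code (illustrative only, not part of the patch; the context name and sample data are made up). On the driver the accumulator was never deserialized, so value still reads normally; the copy a task deserializes would now throw:

import spark.SparkContext
import spark.SparkContext._   // brings the implicit IntAccumulatorParam into scope

val sc = new SparkContext("local", "accumulator-sketch")
val sum = sc.accumulator(0)
sc.parallelize(1 to 10).foreach { i =>
  sum += i          // adding to the accumulator from a task is still allowed
  // sum.value      // the task's copy is deserialized, so this would now throw UnsupportedOperationException
}
println(sum.value)  // on the driver, value is readable and prints 55
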
diff --git a/core/src/test/scala/spark/AccumulatorSuite.scala b/core/src/test/scala/spark/AccumulatorSuite.scala
index d9ef8797d6..a59b77fc85 100644
--- a/core/src/test/scala/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/spark/AccumulatorSuite.scala
@@ -63,60 +63,19 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers {
}
- test ("value readable in tasks") {
- import spark.util.Vector
- //stochastic gradient descent with weights stored in accumulator -- should be able to read value as we go
-
- //really easy data
- val N = 10000 // Number of data points
- val D = 10 // Numer of dimensions
- val R = 0.7 // Scaling factor
- val ITERATIONS = 5
- val rand = new Random(42)
-
- case class DataPoint(x: Vector, y: Double)
-
- def generateData = {
- def generatePoint(i: Int) = {
- val y = if(i % 2 == 0) -1 else 1
- val goodX = Vector(D, _ => 0.0001 * rand.nextGaussian() + y)
- val noiseX = Vector(D, _ => rand.nextGaussian())
- val x = Vector((goodX.elements.toSeq ++ noiseX.elements.toSeq): _*)
- DataPoint(x, y)
- }
- Array.tabulate(N)(generatePoint)
- }
-
- val data = generateData
- for (nThreads <- List(1, 10)) {
- //test single & multi-threaded
- val sc = new SparkContext("local[" + nThreads + "]", "test")
- val weights = Vector.zeros(2*D)
- val weightDelta = sc.accumulator(Vector.zeros(2 * D))
- for (itr <- 1 to ITERATIONS) {
- val eta = 0.1 / itr
- val badErrs = sc.accumulator(0)
- sc.parallelize(data).foreach {
- p => {
- //XXX Note the call to .value here. That is required for this to be an online gradient descent
- // instead of a batch version. Should it change to .localValue, and should .value throw an error
- // if you try to do this??
- val prod = weightDelta.value.plusDot(weights, p.x)
- val trueClassProb = (1 / (1 + exp(-p.y * prod))) // works b/c p(-z) = 1 - p(z) (where p is the logistic function)
- val update = p.x * trueClassProb * p.y * eta
- //we could also include a momentum term here if our weightDelta accumulator saved a momentum
- weightDelta.value += update
- if (trueClassProb <= 0.95)
- badErrs += 1
- }
+ test ("value not readable in tasks") {
+ import SetAccum._
+ val maxI = 1000
+ for (nThreads <- List(1, 10)) { //test single & multi-threaded
+ val sc = new SparkContext("local[" + nThreads + "]", "test")
+ val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]())
+ val d = sc.parallelize(1 to maxI)
+ val thrown = evaluating {
+ d.foreach {
+ x => acc.value += x
}
- println("Iteration " + itr + " had badErrs = " + badErrs.value)
- weights += weightDelta.value
- println(weights)
- //TODO I should check the number of bad errors here, but for some reason spark tries to serialize the assertion ...
-// val assertVal = badErrs.value
-// assert (assertVal < 100)
- }
+ } should produce [SparkException]
+ println(thrown)
}
}
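
For context, a rough sketch of the failure path the rewritten test asserts on (illustrative only; the exception type is taken from the test above). Reading acc.value inside a task throws UnsupportedOperationException, and the resulting task failure reaches the driver wrapped in a SparkException, which is what the ScalaTest evaluating ... should produce idiom checks:

import spark.{SparkContext, SparkException}
import spark.SparkContext._

val sc = new SparkContext("local", "accumulator-failure-sketch")
val acc = sc.accumulator(0)
try {
  sc.parallelize(1 to 10).foreach { _ =>
    acc.value       // deserialized copy inside the task: throws UnsupportedOperationException
  }
} catch {
  case e: SparkException => println("task failure surfaced as: " + e)
}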