author     WeichenXu <WeichenXu123@outlook.com>    2016-05-18 11:48:46 +0100
committer  Sean Owen <sowen@cloudera.com>          2016-05-18 11:48:46 +0100
commit     2f9047b5eb969e0198b8a73e392642ca852ba786
tree       152fe58ada0fa73a5a5e151b4d0ce188c65be0b5
parent     33814f887aea339c99e14ce7f14ca6fcc6875015
[SPARK-15322][MLLIB][CORE][SQL] Update deprecated accumulator usage to AccumulatorV2 across the Spark project
## What changes were proposed in this pull request?

I used IntelliJ IDEA to search for usages of the deprecated SparkContext.accumulator API across the whole Spark project and updated the code to AccumulatorV2 (except test code that exercises the deprecated accumulator methods themselves).

## How was this patch tested?

Existing unit tests.

Author: WeichenXu <WeichenXu123@outlook.com>

Closes #13112 from WeichenXu123/update_accuV2_in_mllib.
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala                            |  8
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala  | 11
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala                                |  7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala                            |  4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala                             |  8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala                       |  8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala                                |  8
7 files changed, 28 insertions(+), 26 deletions(-)
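Before the per-file diffs, here is a minimal before/after sketch of the migration this commit applies everywhere. The object and method names (AccumulatorMigrationSketch, countElements) and the accumulator name "elementCount" are illustrative, not part of the patch; only sc.longAccumulator, add and the value/sum readers come from the API the patch adopts.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.util.LongAccumulator

object AccumulatorMigrationSketch {
  // Before (deprecated Accumulator API):
  //   val acc = sc.accumulator(0L)
  //   rdd.foreach(x => acc += 1)
  //
  // After (AccumulatorV2-based API), assuming an existing SparkContext `sc`:
  def countElements(sc: SparkContext): Long = {
    val acc: LongAccumulator = sc.longAccumulator("elementCount") // creates and registers the accumulator
    sc.parallelize(1 to 100, 4).foreach(_ => acc.add(1))          // `+=` is replaced by `add()`
    acc.sum                                                       // aggregated total as a Scala Long; `acc.value` also works
  }
}
```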
diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
index 8cb0a295b0..58664e77d2 100644
--- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
@@ -65,9 +65,9 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim
test("foreachAsync") {
zeroPartRdd.foreachAsync(i => Unit).get()
- val accum = sc.accumulator(0)
+ val accum = sc.longAccumulator
sc.parallelize(1 to 1000, 3).foreachAsync { i =>
- accum += 1
+ accum.add(1)
}.get()
assert(accum.value === 1000)
}
@@ -75,9 +75,9 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim
test("foreachPartitionAsync") {
zeroPartRdd.foreachPartitionAsync(iter => Unit).get()
- val accum = sc.accumulator(0)
+ val accum = sc.longAccumulator
sc.parallelize(1 to 1000, 9).foreachPartitionAsync { iter =>
- accum += 1
+ accum.add(1)
}.get()
assert(accum.value === 9)
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala
index 1bcd85e1d5..acbcb0c4b7 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala
@@ -23,11 +23,12 @@ import java.nio.charset.Charset
import com.google.common.io.Files
-import org.apache.spark.{Accumulator, SparkConf, SparkContext}
+import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Seconds, StreamingContext, Time}
import org.apache.spark.util.IntParam
+import org.apache.spark.util.LongAccumulator
/**
* Use this singleton to get or register a Broadcast variable.
@@ -54,13 +55,13 @@ object WordBlacklist {
*/
object DroppedWordsCounter {
- @volatile private var instance: Accumulator[Long] = null
+ @volatile private var instance: LongAccumulator = null
- def getInstance(sc: SparkContext): Accumulator[Long] = {
+ def getInstance(sc: SparkContext): LongAccumulator = {
if (instance == null) {
synchronized {
if (instance == null) {
- instance = sc.accumulator(0L, "WordsInBlacklistCounter")
+ instance = sc.longAccumulator("WordsInBlacklistCounter")
}
}
}
@@ -124,7 +125,7 @@ object RecoverableNetworkWordCount {
// Use blacklist to drop words and use droppedWordsCounter to count them
val counts = rdd.filter { case (word, count) =>
if (blacklist.value.contains(word)) {
- droppedWordsCounter += count
+ droppedWordsCounter.add(count)
false
} else {
true
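The pattern updated above, lazily creating a single named LongAccumulator per SparkContext so it can be looked up again after a streaming driver restart, generalizes beyond this example. A standalone sketch follows; the object name DroppedRecordsCounter and the accumulator name are hypothetical, while sc.longAccumulator(name) is the call the patch switches to.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.util.LongAccumulator

// Lazily create and register one LongAccumulator per SparkContext
// (double-checked locking), mirroring the DroppedWordsCounter change above.
object DroppedRecordsCounter {

  @volatile private var instance: LongAccumulator = null

  def getInstance(sc: SparkContext): LongAccumulator = {
    if (instance == null) {
      synchronized {
        if (instance == null) {
          // longAccumulator(name) both creates and registers the accumulator
          instance = sc.longAccumulator("DroppedRecordsCounter")
        }
      }
    }
    instance
  }
}
```

Inside the foreachRDD body, task closures then call getInstance(...).add(count) exactly as in the filter closure above.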
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala
index 8d4174124b..e79b1f3164 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala
@@ -19,7 +19,8 @@ package org.apache.spark.ml.util
import scala.collection.mutable
-import org.apache.spark.{Accumulator, SparkContext}
+import org.apache.spark.SparkContext
+import org.apache.spark.util.LongAccumulator;
/**
* Abstract class for stopwatches.
@@ -102,12 +103,12 @@ private[spark] class DistributedStopwatch(
sc: SparkContext,
override val name: String) extends Stopwatch {
- private val elapsedTime: Accumulator[Long] = sc.accumulator(0L, s"DistributedStopwatch($name)")
+ private val elapsedTime: LongAccumulator = sc.longAccumulator(s"DistributedStopwatch($name)")
override def elapsed(): Long = elapsedTime.value
override protected def add(duration: Long): Unit = {
- elapsedTime += duration
+ elapsedTime.add(duration)
}
}
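DistributedStopwatch above simply swaps its backing Accumulator[Long] for a LongAccumulator. Below is a compact sketch of the same idea, not the ml.util implementation itself; the class name SimpleDistributedTimer and its methods are hypothetical.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.util.LongAccumulator

// Accumulate wall-clock time measured inside tasks into a driver-visible total.
class SimpleDistributedTimer(sc: SparkContext, name: String) extends Serializable {

  private val elapsedMs: LongAccumulator = sc.longAccumulator(s"SimpleDistributedTimer($name)")

  // Time a block of work running inside a task and record its duration.
  def time[T](body: => T): T = {
    val start = System.currentTimeMillis()
    try {
      body
    } finally {
      elapsedMs.add(System.currentTimeMillis() - start)
    }
  }

  // Read the aggregated total (in milliseconds) on the driver after the action completes.
  def totalElapsed: Long = elapsedMs.sum
}
```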
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 60f13d27d0..38728f2693 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -279,7 +279,7 @@ class KMeans private (
}
val activeCenters = activeRuns.map(r => centers(r)).toArray
- val costAccums = activeRuns.map(_ => sc.accumulator(0.0))
+ val costAccums = activeRuns.map(_ => sc.doubleAccumulator)
val bcActiveCenters = sc.broadcast(activeCenters)
@@ -296,7 +296,7 @@ class KMeans private (
points.foreach { point =>
(0 until runs).foreach { i =>
val (bestCenter, cost) = KMeans.findClosest(thisActiveCenters(i), point)
- costAccums(i) += cost
+ costAccums(i).add(cost)
val sum = sums(i)(bestCenter)
axpy(1.0, point.vector, sum)
counts(i)(bestCenter) += 1
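KMeans keeps one cost accumulator per active run, and the change replaces sc.accumulator(0.0) with sc.doubleAccumulator. A self-contained sketch of that per-run pattern, with made-up data and a squared value standing in for the real per-point cost from KMeans.findClosest, assuming an existing SparkContext `sc`:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.util.DoubleAccumulator

object PerRunCostSketch {
  // One DoubleAccumulator per run, updated inside the task closure,
  // read back on the driver after the action completes.
  def accumulateCosts(sc: SparkContext, runs: Int): Seq[Double] = {
    val costAccums: Seq[DoubleAccumulator] = (0 until runs).map(_ => sc.doubleAccumulator)
    val points = sc.parallelize(Seq(1.0, 2.0, 3.0), numSlices = 3)
    points.foreach { p =>
      (0 until runs).foreach(i => costAccums(i).add(p * p)) // stand-in for the real cost
    }
    costAccums.map(_.sum) // DoubleAccumulator.sum returns the aggregated Double
  }
}
```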
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala
index 9e6bc7193c..141249a427 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala
@@ -60,9 +60,9 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext {
test("DistributedStopwatch on executors") {
val sw = new DistributedStopwatch(sc, "sw")
val rdd = sc.parallelize(0 until 4, 4)
- val acc = sc.accumulator(0L)
+ val acc = sc.longAccumulator
rdd.foreach { i =>
- acc += checkStopwatch(sw)
+ acc.add(checkStopwatch(sw))
}
assert(!sw.isRunning)
val elapsed = sw.elapsed()
@@ -88,12 +88,12 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(sw.toString ===
s"{\n local: ${localElapsed}ms,\n spark: ${sparkElapsed}ms\n}")
val rdd = sc.parallelize(0 until 4, 4)
- val acc = sc.accumulator(0L)
+ val acc = sc.longAccumulator
rdd.foreach { i =>
sw("local").start()
val duration = checkStopwatch(sw("spark"))
sw("local").stop()
- acc += duration
+ acc.add(duration)
}
val localElapsed2 = sw("local").elapsed()
assert(localElapsed2 === localElapsed)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index ff022b2dc4..a634502e2e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -62,15 +62,15 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
test("foreach") {
val ds = Seq(1, 2, 3).toDS()
- val acc = sparkContext.accumulator(0)
- ds.foreach(acc += _)
+ val acc = sparkContext.longAccumulator
+ ds.foreach(acc.add(_))
assert(acc.value == 6)
}
test("foreachPartition") {
val ds = Seq(1, 2, 3).toDS()
- val acc = sparkContext.accumulator(0)
- ds.foreachPartition(_.foreach(acc +=))
+ val acc = sparkContext.longAccumulator
+ ds.foreachPartition(_.foreach(acc.add(_)))
assert(acc.value == 6)
}
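For reference, the test pattern above as a standalone program; a hedged sketch assuming local mode, where the object and app names are made up while SparkSession, toDS and longAccumulator are the real APIs. Because foreach is an action, each successful task's accumulator updates are applied exactly once, which is what makes the final assertion deterministic.

```scala
import org.apache.spark.sql.SparkSession

object DatasetAccumulatorExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("dataset-accumulator").getOrCreate()
    import spark.implicits._

    val acc = spark.sparkContext.longAccumulator("sum")
    val ds = Seq(1, 2, 3).toDS()
    ds.foreach(v => acc.add(v))   // foreach is an action, so updates reach the driver
    assert(acc.value == 6)

    spark.stop()
  }
}
```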
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 3b9feae4a3..b02b714168 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -207,15 +207,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("foreach") {
val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
- val acc = sparkContext.accumulator(0)
- ds.foreach(v => acc += v._2)
+ val acc = sparkContext.longAccumulator
+ ds.foreach(v => acc.add(v._2))
assert(acc.value == 6)
}
test("foreachPartition") {
val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
- val acc = sparkContext.accumulator(0)
- ds.foreachPartition(_.foreach(v => acc += v._2))
+ val acc = sparkContext.longAccumulator
+ ds.foreachPartition(_.foreach(v => acc.add(v._2)))
assert(acc.value == 6)
}
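Beyond the built-in long and double accumulators used throughout this patch, the AccumulatorV2 base class also supports user-defined accumulators. The following is an illustrative sketch only, not part of this commit: a hypothetical StringSetAccumulator that collects distinct strings seen by tasks, registered via SparkContext.register.

```scala
import scala.collection.mutable

import org.apache.spark.util.AccumulatorV2

// Hypothetical user-defined accumulator: collects distinct strings across tasks.
class StringSetAccumulator extends AccumulatorV2[String, Set[String]] {

  private val set = mutable.Set.empty[String]

  override def isZero: Boolean = set.isEmpty

  override def copy(): StringSetAccumulator = {
    val newAcc = new StringSetAccumulator
    newAcc.set ++= set
    newAcc
  }

  override def reset(): Unit = set.clear()

  override def add(v: String): Unit = { set += v }

  override def merge(other: AccumulatorV2[String, Set[String]]): Unit = { set ++= other.value }

  override def value: Set[String] = set.toSet
}

// Registration and use, assuming an existing SparkContext `sc`:
//   val acc = new StringSetAccumulator
//   sc.register(acc, "distinctWords")
//   sc.parallelize(Seq("a", "b", "a")).foreach(word => acc.add(word))
//   acc.value  // Set("a", "b") on the driver
```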