aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/org/apache/spark/ContextCleaner.scala26
-rw-r--r--core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala20
-rw-r--r--core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala23
-rw-r--r--core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala10
4 files changed, 40 insertions, 39 deletions
diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
index 5a42299a0b..17014e4954 100644
--- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -18,9 +18,9 @@
package org.apache.spark
import java.lang.ref.{ReferenceQueue, WeakReference}
-import java.util.concurrent.{ScheduledExecutorService, TimeUnit}
+import java.util.concurrent.{ConcurrentLinkedQueue, ScheduledExecutorService, TimeUnit}
-import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
+import scala.collection.JavaConverters._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData}
@@ -57,13 +57,11 @@ private class CleanupTaskWeakReference(
*/
private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
- private val referenceBuffer = new ArrayBuffer[CleanupTaskWeakReference]
- with SynchronizedBuffer[CleanupTaskWeakReference]
+ private val referenceBuffer = new ConcurrentLinkedQueue[CleanupTaskWeakReference]()
private val referenceQueue = new ReferenceQueue[AnyRef]
- private val listeners = new ArrayBuffer[CleanerListener]
- with SynchronizedBuffer[CleanerListener]
+ private val listeners = new ConcurrentLinkedQueue[CleanerListener]()
private val cleaningThread = new Thread() { override def run() { keepCleaning() }}
@@ -111,7 +109,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
/** Attach a listener object to get information of when objects are cleaned. */
def attachListener(listener: CleanerListener): Unit = {
- listeners += listener
+ listeners.add(listener)
}
/** Start the cleaner. */
@@ -166,7 +164,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
/** Register an object for cleanup. */
private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask): Unit = {
- referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue)
+ referenceBuffer.add(new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue))
}
/** Keep cleaning RDD, shuffle, and broadcast state. */
@@ -179,7 +177,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
synchronized {
reference.map(_.task).foreach { task =>
logDebug("Got cleaning task " + task)
- referenceBuffer -= reference.get
+ referenceBuffer.remove(reference.get)
task match {
case CleanRDD(rddId) =>
doCleanupRDD(rddId, blocking = blockOnCleanupTasks)
@@ -206,7 +204,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
try {
logDebug("Cleaning RDD " + rddId)
sc.unpersistRDD(rddId, blocking)
- listeners.foreach(_.rddCleaned(rddId))
+ listeners.asScala.foreach(_.rddCleaned(rddId))
logInfo("Cleaned RDD " + rddId)
} catch {
case e: Exception => logError("Error cleaning RDD " + rddId, e)
@@ -219,7 +217,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
logDebug("Cleaning shuffle " + shuffleId)
mapOutputTrackerMaster.unregisterShuffle(shuffleId)
blockManagerMaster.removeShuffle(shuffleId, blocking)
- listeners.foreach(_.shuffleCleaned(shuffleId))
+ listeners.asScala.foreach(_.shuffleCleaned(shuffleId))
logInfo("Cleaned shuffle " + shuffleId)
} catch {
case e: Exception => logError("Error cleaning shuffle " + shuffleId, e)
@@ -231,7 +229,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
try {
logDebug(s"Cleaning broadcast $broadcastId")
broadcastManager.unbroadcast(broadcastId, true, blocking)
- listeners.foreach(_.broadcastCleaned(broadcastId))
+ listeners.asScala.foreach(_.broadcastCleaned(broadcastId))
logDebug(s"Cleaned broadcast $broadcastId")
} catch {
case e: Exception => logError("Error cleaning broadcast " + broadcastId, e)
@@ -243,7 +241,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
try {
logDebug("Cleaning accumulator " + accId)
Accumulators.remove(accId)
- listeners.foreach(_.accumCleaned(accId))
+ listeners.asScala.foreach(_.accumCleaned(accId))
logInfo("Cleaned accumulator " + accId)
} catch {
case e: Exception => logError("Error cleaning accumulator " + accId, e)
@@ -258,7 +256,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
try {
logDebug("Cleaning rdd checkpoint data " + rddId)
ReliableRDDCheckpointData.cleanCheckpoint(sc, rddId)
- listeners.foreach(_.checkpointCleaned(rddId))
+ listeners.asScala.foreach(_.checkpointCleaned(rddId))
logInfo("Cleaned rdd checkpoint data " + rddId)
}
catch {
diff --git a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala
index eb794b6739..658779360b 100644
--- a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala
@@ -17,7 +17,9 @@
package org.apache.spark.deploy.client
-import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
+import java.util.concurrent.ConcurrentLinkedQueue
+
+import scala.collection.JavaConverters._
import scala.concurrent.duration._
import org.scalatest.BeforeAndAfterAll
@@ -165,14 +167,14 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd
/** Application Listener to collect events */
private class AppClientCollector extends AppClientListener with Logging {
- val connectedIdList = new ArrayBuffer[String] with SynchronizedBuffer[String]
+ val connectedIdList = new ConcurrentLinkedQueue[String]()
@volatile var disconnectedCount: Int = 0
- val deadReasonList = new ArrayBuffer[String] with SynchronizedBuffer[String]
- val execAddedList = new ArrayBuffer[String] with SynchronizedBuffer[String]
- val execRemovedList = new ArrayBuffer[String] with SynchronizedBuffer[String]
+ val deadReasonList = new ConcurrentLinkedQueue[String]()
+ val execAddedList = new ConcurrentLinkedQueue[String]()
+ val execRemovedList = new ConcurrentLinkedQueue[String]()
def connected(id: String): Unit = {
- connectedIdList += id
+ connectedIdList.add(id)
}
def disconnected(): Unit = {
@@ -182,7 +184,7 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd
}
def dead(reason: String): Unit = {
- deadReasonList += reason
+ deadReasonList.add(reason)
}
def executorAdded(
@@ -191,11 +193,11 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd
hostPort: String,
cores: Int,
memory: Int): Unit = {
- execAddedList += id
+ execAddedList.add(id)
}
def executorRemoved(id: String, message: String, exitStatus: Option[Int]): Unit = {
- execRemovedList += id
+ execRemovedList.add(id)
}
}
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index 6f4eda8b47..2204800388 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -20,9 +20,10 @@ package org.apache.spark.rpc
import java.io.{File, NotSerializableException}
import java.nio.charset.StandardCharsets.UTF_8
import java.util.UUID
-import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit}
+import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch, TimeoutException, TimeUnit}
import scala.collection.mutable
+import scala.collection.JavaConverters._
import scala.concurrent.Await
import scala.concurrent.duration._
import scala.language.postfixOps
@@ -490,30 +491,30 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
/**
* Setup an [[RpcEndpoint]] to collect all network events.
- * @return the [[RpcEndpointRef]] and an `Seq` that contains network events.
+ * @return the [[RpcEndpointRef]] and an `ConcurrentLinkedQueue` that contains network events.
*/
private def setupNetworkEndpoint(
_env: RpcEnv,
- name: String): (RpcEndpointRef, Seq[(Any, Any)]) = {
- val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)]
+ name: String): (RpcEndpointRef, ConcurrentLinkedQueue[(Any, Any)]) = {
+ val events = new ConcurrentLinkedQueue[(Any, Any)]
val ref = _env.setupEndpoint("network-events-non-client", new ThreadSafeRpcEndpoint {
override val rpcEnv = _env
override def receive: PartialFunction[Any, Unit] = {
case "hello" =>
- case m => events += "receive" -> m
+ case m => events.add("receive" -> m)
}
override def onConnected(remoteAddress: RpcAddress): Unit = {
- events += "onConnected" -> remoteAddress
+ events.add("onConnected" -> remoteAddress)
}
override def onDisconnected(remoteAddress: RpcAddress): Unit = {
- events += "onDisconnected" -> remoteAddress
+ events.add("onDisconnected" -> remoteAddress)
}
override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
- events += "onNetworkError" -> remoteAddress
+ events.add("onNetworkError" -> remoteAddress)
}
})
@@ -560,7 +561,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
eventually(timeout(5 seconds), interval(5 millis)) {
// We don't know the exact client address but at least we can verify the message type
- assert(events.map(_._1).contains("onConnected"))
+ assert(events.asScala.map(_._1).exists(_ == "onConnected"))
}
clientEnv.shutdown()
@@ -568,8 +569,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
eventually(timeout(5 seconds), interval(5 millis)) {
// We don't know the exact client address but at least we can verify the message type
- assert(events.map(_._1).contains("onConnected"))
- assert(events.map(_._1).contains("onDisconnected"))
+ assert(events.asScala.map(_._1).exists(_ == "onConnected"))
+ assert(events.asScala.map(_._1).exists(_ == "onDisconnected"))
}
} finally {
clientEnv.shutdown()
diff --git a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala
index b207d497f3..6f7dddd4f7 100644
--- a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala
@@ -17,9 +17,9 @@
package org.apache.spark.util
-import java.util.concurrent.CountDownLatch
+import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch}
-import scala.collection.mutable
+import scala.collection.JavaConverters._
import scala.concurrent.duration._
import scala.language.postfixOps
@@ -31,11 +31,11 @@ import org.apache.spark.SparkFunSuite
class EventLoopSuite extends SparkFunSuite with Timeouts {
test("EventLoop") {
- val buffer = new mutable.ArrayBuffer[Int] with mutable.SynchronizedBuffer[Int]
+ val buffer = new ConcurrentLinkedQueue[Int]
val eventLoop = new EventLoop[Int]("test") {
override def onReceive(event: Int): Unit = {
- buffer += event
+ buffer.add(event)
}
override def onError(e: Throwable): Unit = {}
@@ -43,7 +43,7 @@ class EventLoopSuite extends SparkFunSuite with Timeouts {
eventLoop.start()
(1 to 100).foreach(eventLoop.post)
eventually(timeout(5 seconds), interval(5 millis)) {
- assert((1 to 100) === buffer.toSeq)
+ assert((1 to 100) === buffer.asScala.toSeq)
}
eventLoop.stop()
}