aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala24
-rw-r--r--core/src/main/scala/spark/broadcast/MultiTracker.scala2
-rw-r--r--core/src/main/scala/spark/broadcast/SourceInfo.scala5
-rw-r--r--core/src/main/scala/spark/broadcast/TreeBroadcast.scala20
-rw-r--r--examples/src/main/scala/spark/examples/BroadcastTest.scala10
5 files changed, 35 insertions, 26 deletions
diff --git a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
index cf20f456c4..ef27bbb502 100644
--- a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
@@ -311,9 +311,11 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots)
while (hasBlocks.get < totalBlocks) {
- var numThreadsToCreate =
- math.min(listOfSources.size, MultiTracker.MaxChatSlots) -
+ var numThreadsToCreate = 0
+ listOfSources.synchronized {
+ numThreadsToCreate = math.min(listOfSources.size, MultiTracker.MaxChatSlots) -
threadPool.getActiveCount
+ }
while (hasBlocks.get < totalBlocks && numThreadsToCreate > 0) {
var peerToTalkTo = pickPeerToTalkToRandom
@@ -726,7 +728,6 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
guidePortLock.synchronized { guidePortLock.notifyAll() }
try {
- // Don't stop until there is a copy in HDFS
while (!stopBroadcast) {
var clientSocket: Socket = null
try {
@@ -734,14 +735,17 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
clientSocket = serverSocket.accept()
} catch {
case e: Exception => {
- logError("GuideMultipleRequests Timeout.")
-
// Stop broadcast if at least one worker has connected and
// everyone connected so far are done. Comparing with
// listOfSources.size - 1, because it includes the Guide itself
- if (listOfSources.size > 1 &&
- setOfCompletedSources.size == listOfSources.size - 1) {
- stopBroadcast = true
+ listOfSources.synchronized {
+ setOfCompletedSources.synchronized {
+ if (listOfSources.size > 1 &&
+ setOfCompletedSources.size == listOfSources.size - 1) {
+ stopBroadcast = true
+ logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.")
+ }
+ }
}
}
}
@@ -922,9 +926,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
clientSocket = serverSocket.accept()
} catch {
- case e: Exception => {
- logError("ServeMultipleRequests Timeout.")
- }
+ case e: Exception => { }
}
if (clientSocket != null) {
logDebug("Serve: Accepted new client connection:" + clientSocket)
diff --git a/core/src/main/scala/spark/broadcast/MultiTracker.scala b/core/src/main/scala/spark/broadcast/MultiTracker.scala
index dd8e6dd246..5e76dedb94 100644
--- a/core/src/main/scala/spark/broadcast/MultiTracker.scala
+++ b/core/src/main/scala/spark/broadcast/MultiTracker.scala
@@ -228,7 +228,7 @@ extends Logging {
var oosTracker: ObjectOutputStream = null
var oisTracker: ObjectInputStream = null
- var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxOverGoToDefault)
+ var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxNotStartedRetry)
var retriesLeft = MultiTracker.MaxRetryCount
do {
diff --git a/core/src/main/scala/spark/broadcast/SourceInfo.scala b/core/src/main/scala/spark/broadcast/SourceInfo.scala
index 705dd6fd81..c79bb93c38 100644
--- a/core/src/main/scala/spark/broadcast/SourceInfo.scala
+++ b/core/src/main/scala/spark/broadcast/SourceInfo.scala
@@ -27,9 +27,10 @@ extends Comparable[SourceInfo] with Logging {
* Helper Object of SourceInfo for its constants
*/
private[spark] object SourceInfo {
- // Constants for special values of listenPort
+ // Broadcast has not started yet! Should never happen.
val TxNotStartedRetry = -1
- val TxOverGoToDefault = 0
+ // Broadcast has already finished. Try default mechanism.
+ val TxOverGoToDefault = -3
// Other constants
val StopBroadcast = -2
val UnusedParam = 0
diff --git a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala
index 5bd40a40e3..fa676e9064 100644
--- a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala
@@ -292,15 +292,17 @@ extends Broadcast[T](id) with Logging with Serializable {
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
- logError("GuideMultipleRequests Timeout.")
-
// Stop broadcast if at least one worker has connected and
- // everyone connected so far are done.
- // Comparing with listOfSources.size - 1, because the Guide itself
- // is included
- if (listOfSources.size > 1 &&
- setOfCompletedSources.size == listOfSources.size - 1) {
- stopBroadcast = true
+ // everyone connected so far are done. Comparing with
+ // listOfSources.size - 1, because it includes the Guide itself
+ listOfSources.synchronized {
+ setOfCompletedSources.synchronized {
+ if (listOfSources.size > 1 &&
+ setOfCompletedSources.size == listOfSources.size - 1) {
+ stopBroadcast = true
+ logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.")
+ }
+ }
}
}
}
@@ -492,7 +494,7 @@ extends Broadcast[T](id) with Logging with Serializable {
serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
- case e: Exception => logError("ServeMultipleRequests Timeout.")
+ case e: Exception => { }
}
if (clientSocket != null) {
diff --git a/examples/src/main/scala/spark/examples/BroadcastTest.scala b/examples/src/main/scala/spark/examples/BroadcastTest.scala
index 08be49a41c..230097c7db 100644
--- a/examples/src/main/scala/spark/examples/BroadcastTest.scala
+++ b/examples/src/main/scala/spark/examples/BroadcastTest.scala
@@ -17,9 +17,13 @@ object BroadcastTest {
for (i <- 0 until arr1.length)
arr1(i) = i
- val barr1 = spark.broadcast(arr1)
- spark.parallelize(1 to 10, slices).foreach {
- i => println(barr1.value.size)
+ for (i <- 0 until 2) {
+ println("Iteration " + i)
+ println("===========")
+ val barr1 = spark.broadcast(arr1)
+ spark.parallelize(1 to 10, slices).foreach {
+ i => println(barr1.value.size)
+ }
}
System.exit(0)