From 802aa8aef90fe7d2f0c859c68f12361db286bf20 Mon Sep 17 00:00:00 2001
From: Matei Zaharia
Date: Mon, 1 Oct 2012 15:20:42 -0700
Subject: Some bug fixes and logging fixes for broadcast.

---
 core/src/main/scala/spark/RDD.scala                 |  2 +-
 .../spark/broadcast/BitTorrentBroadcast.scala       | 49 +++++++++++---------
 .../src/main/scala/spark/broadcast/Broadcast.scala  | 22 +++------
 .../scala/spark/broadcast/BroadcastFactory.scala    |  2 +-
 .../main/scala/spark/broadcast/HttpBroadcast.scala  | 39 ++++++++--------
 .../main/scala/spark/broadcast/MultiTracker.scala   | 47 ++++++++++---------
 .../main/scala/spark/broadcast/TreeBroadcast.scala  | 53 ++++++++++++----------
 .../main/scala/spark/storage/BlockManager.scala     |  2 -
 .../src/main/scala/spark/storage/MemoryStore.scala  |  2 +
 9 files changed, 111 insertions(+), 107 deletions(-)

(limited to 'core/src/main')

diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index ab8014c056..351c3d9d0b 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -99,7 +99,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
 
   def getStorageLevel = storageLevel
 
-  def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = {
+  private[spark] def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = {
     if (!level.useDisk && level.replication < 2) {
       throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")")
     }
diff --git a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
index 0b9647d168..b72e8986d3 100644
--- a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
@@ -11,14 +11,17 @@ import scala.math
 import spark._
 import spark.storage.StorageLevel
 
-class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean)
-extends Broadcast[T] with Logging with Serializable {
+class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
+  extends Broadcast[T](id)
+  with Logging
+  with Serializable {
 
   def value = value_
 
+  def blockId: String = "broadcast_" + id
+
   MultiTracker.synchronized {
-    SparkEnv.get.blockManager.putSingle(
-      uuid.toString, value_, StorageLevel.MEMORY_AND_DISK, false)
+    SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
   }
 
   @transient var arrayOfBlocks: Array[BroadcastBlock] = null
@@ -45,7 +48,7 @@ extends Broadcast[T] with Logging with Serializable {
   // Used only in Workers
   @transient var ttGuide: TalkToGuide = null
 
-  @transient var hostAddress = Utils.localIpAddress
+  @transient var hostAddress = Utils.localIpAddress()
   @transient var listenPort = -1
   @transient var guidePort = -1
 
@@ -106,17 +109,19 @@ extends Broadcast[T] with Logging with Serializable {
     listOfSources += masterSource
 
     // Register with the Tracker
-    MultiTracker.registerBroadcast(uuid,
+    MultiTracker.registerBroadcast(id,
       SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
   }
 
   private def readObject(in: ObjectInputStream) {
     in.defaultReadObject()
     MultiTracker.synchronized {
-      SparkEnv.get.blockManager.getSingle(uuid.toString) match {
-        case Some(x) => x.asInstanceOf[T]
-        case None => {
-          logInfo("Started reading broadcast variable " + uuid)
+      SparkEnv.get.blockManager.getSingle(blockId) match {
+        case Some(x) =>
+          value_ = x.asInstanceOf[T]
+
+        case None =>
+          logInfo("Started reading broadcast variable " + id)
           // Initializing everything because Master will only send null/0 values
           // Only the 1st worker in a node can be here. Others will get from cache
           initializeWorkerVariables()
@@ -131,18 +136,17 @@ extends Broadcast[T] with Logging with Serializable {
 
           val start = System.nanoTime
 
-          val receptionSucceeded = receiveBroadcast(uuid)
+          val receptionSucceeded = receiveBroadcast(id)
           if (receptionSucceeded) {
             value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
             SparkEnv.get.blockManager.putSingle(
-              uuid.toString, value_, StorageLevel.MEMORY_AND_DISK, false)
+              blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
           } else {
-            logError("Reading Broadcasted variable " + uuid + " failed")
+            logError("Reading broadcast variable " + id + " failed")
           }
 
           val time = (System.nanoTime - start) / 1e9
-          logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
-        }
+          logInfo("Reading broadcast variable " + id + " took " + time + " s")
       }
     }
   }
@@ -254,8 +258,8 @@ extends Broadcast[T] with Logging with Serializable {
     }
   }
 
-  def receiveBroadcast(variableUUID: UUID): Boolean = {
-    val gInfo = MultiTracker.getGuideInfo(variableUUID)
+  def receiveBroadcast(variableID: Long): Boolean = {
+    val gInfo = MultiTracker.getGuideInfo(variableID)
 
     if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
       return false
@@ -760,7 +764,7 @@ extends Broadcast[T] with Logging with Serializable {
         logInfo("Sending stopBroadcast notifications...")
         sendStopBroadcastNotifications
 
-        MultiTracker.unregisterBroadcast(uuid)
+        MultiTracker.unregisterBroadcast(id)
       } finally {
         if (serverSocket != null) {
           logInfo("GuideMultipleRequests now stopping...")
@@ -1025,7 +1029,10 @@ extends Broadcast[T] with Logging with Serializable {
 
 class BitTorrentBroadcastFactory
 extends BroadcastFactory {
-  def initialize(isMaster: Boolean) = MultiTracker.initialize(isMaster)
-  def newBroadcast[T](value_ : T, isLocal: Boolean) = new BitTorrentBroadcast[T](value_, isLocal)
-  def stop() = MultiTracker.stop
+  def initialize(isMaster: Boolean) { MultiTracker.initialize(isMaster) }
+
+  def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
+    new BitTorrentBroadcast[T](value_, isLocal, id)
+
+  def stop() { MultiTracker.stop() }
 }
diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala
index d68e56a114..3ba91c93e9 100644
--- a/core/src/main/scala/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/spark/broadcast/Broadcast.scala
@@ -1,23 +1,17 @@
 package spark.broadcast
 
 import java.io._
-import java.net._
-import java.util.{BitSet, UUID}
-import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
-
-import scala.collection.mutable.Map
+import java.util.concurrent.atomic.AtomicLong
 
 import spark._
 
-trait Broadcast[T] extends Serializable {
-  val uuid = UUID.randomUUID
-
+abstract class Broadcast[T](id: Long) extends Serializable {
   def value: T
 
   // We cannot have an abstract readObject here due to some weird issues with
   // readObject having to be 'private' in sub-classes.
 
-  override def toString = "spark.Broadcast(" + uuid + ")"
+  override def toString = "spark.Broadcast(" + id + ")"
 }
 
 class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializable {
@@ -49,14 +43,10 @@ class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializabl
     broadcastFactory.stop()
   }
 
-  private def getBroadcastFactory: BroadcastFactory = {
-    if (broadcastFactory == null) {
-      throw new SparkException ("Broadcast.getBroadcastFactory called before initialize")
-    }
-    broadcastFactory
-  }
+  private val nextBroadcastId = new AtomicLong(0)
 
-  def newBroadcast[T](value_ : T, isLocal: Boolean) = broadcastFactory.newBroadcast[T](value_, isLocal)
+  def newBroadcast[T](value_ : T, isLocal: Boolean) =
+    broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
 
   def isMaster = isMaster_
 }
diff --git a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala
index e341d556bf..66ca8d56d5 100644
--- a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala
+++ b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala
@@ -8,6 +8,6 @@ package spark.broadcast
  */
 trait BroadcastFactory {
   def initialize(isMaster: Boolean): Unit
-  def newBroadcast[T](value_ : T, isLocal: Boolean): Broadcast[T]
+  def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long): Broadcast[T]
   def stop(): Unit
 }
diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
index f5f2b3dbf2..d8cf5e37d4 100644
--- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
@@ -12,34 +12,34 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
 import spark._
 import spark.storage.StorageLevel
 
-class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean)
-extends Broadcast[T] with Logging with Serializable {
+class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
+extends Broadcast[T](id) with Logging with Serializable {
 
   def value = value_
 
+  def blockId: String = "broadcast_" + id
+
   HttpBroadcast.synchronized {
-    SparkEnv.get.blockManager.putSingle(
-      uuid.toString, value_, StorageLevel.MEMORY_AND_DISK, false)
+    SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
   }
 
   if (!isLocal) {
-    HttpBroadcast.write(uuid, value_)
+    HttpBroadcast.write(id, value_)
   }
 
   // Called by JVM when deserializing an object
   private def readObject(in: ObjectInputStream) {
     in.defaultReadObject()
     HttpBroadcast.synchronized {
-      SparkEnv.get.blockManager.getSingle(uuid.toString) match {
+      SparkEnv.get.blockManager.getSingle(blockId) match {
         case Some(x) => value_ = x.asInstanceOf[T]
         case None => {
-          logInfo("Started reading broadcast variable " + uuid)
+          logInfo("Started reading broadcast variable " + id)
           val start = System.nanoTime
-          value_ = HttpBroadcast.read[T](uuid)
-          SparkEnv.get.blockManager.putSingle(
-            uuid.toString, value_, StorageLevel.MEMORY_AND_DISK, false)
+          value_ = HttpBroadcast.read[T](id)
+          SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
           val time = (System.nanoTime - start) / 1e9
-          logInfo("Reading broadcast variable " + uuid + " took " + time + " s")
+          logInfo("Reading broadcast variable " + id + " took " + time + " s")
         }
       }
     }
@@ -47,9 +47,12 @@ extends Broadcast[T] with Logging with Serializable {
 }
 
 class HttpBroadcastFactory extends BroadcastFactory {
-  def initialize(isMaster: Boolean) = HttpBroadcast.initialize(isMaster)
-  def newBroadcast[T](value_ : T, isLocal: Boolean) = new HttpBroadcast[T](value_, isLocal)
-  def stop() = HttpBroadcast.stop()
+  def initialize(isMaster: Boolean) { HttpBroadcast.initialize(isMaster) }
+
+  def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
+    new HttpBroadcast[T](value_, isLocal, id)
+
+  def stop() { HttpBroadcast.stop() }
 }
 
 private object HttpBroadcast extends Logging {
@@ -94,8 +97,8 @@ private object HttpBroadcast extends Logging {
     logInfo("Broadcast server started at " + serverUri)
   }
 
-  def write(uuid: UUID, value: Any) {
-    val file = new File(broadcastDir, "broadcast-" + uuid)
+  def write(id: Long, value: Any) {
+    val file = new File(broadcastDir, "broadcast-" + id)
     val out: OutputStream = if (compress) {
       new LZFOutputStream(new FileOutputStream(file)) // Does its own buffering
     } else {
@@ -107,8 +110,8 @@ private object HttpBroadcast extends Logging {
     serOut.close()
   }
 
-  def read[T](uuid: UUID): T = {
-    val url = serverUri + "/broadcast-" + uuid
+  def read[T](id: Long): T = {
+    val url = serverUri + "/broadcast-" + id
     var in = if (compress) {
       new LZFInputStream(new URL(url).openStream()) // Does its own buffering
     } else {
diff --git a/core/src/main/scala/spark/broadcast/MultiTracker.scala b/core/src/main/scala/spark/broadcast/MultiTracker.scala
index d5f5b22461..d00db23362 100644
--- a/core/src/main/scala/spark/broadcast/MultiTracker.scala
+++ b/core/src/main/scala/spark/broadcast/MultiTracker.scala
@@ -2,8 +2,7 @@ package spark.broadcast
 
 import java.io._
 import java.net._
-import java.util.{UUID, Random}
-import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
+import java.util.Random
 
 import scala.collection.mutable.Map
 
@@ -18,7 +17,7 @@ extends Logging {
   val FIND_BROADCAST_TRACKER = 2
 
   // Map to keep track of guides of ongoing broadcasts
-  var valueToGuideMap = Map[UUID, SourceInfo]()
+  var valueToGuideMap = Map[Long, SourceInfo]()
 
   // Random number generator
   var ranGen = new Random
@@ -154,44 +153,44 @@ extends Logging {
             val messageType = ois.readObject.asInstanceOf[Int]
 
             if (messageType == REGISTER_BROADCAST_TRACKER) {
-              // Receive UUID
-              val uuid = ois.readObject.asInstanceOf[UUID]
+              // Receive Long
+              val id = ois.readObject.asInstanceOf[Long]
               // Receive hostAddress and listenPort
               val gInfo = ois.readObject.asInstanceOf[SourceInfo]
 
               // Add to the map
               valueToGuideMap.synchronized {
-                valueToGuideMap += (uuid -> gInfo)
+                valueToGuideMap += (id -> gInfo)
               }
-              logInfo ("New broadcast " + uuid + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
+              logInfo ("New broadcast " + id + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
 
               // Send dummy ACK
               oos.writeObject(-1)
               oos.flush()
             } else if (messageType == UNREGISTER_BROADCAST_TRACKER) {
-              // Receive UUID
-              val uuid = ois.readObject.asInstanceOf[UUID]
+              // Receive Long
+              val id = ois.readObject.asInstanceOf[Long]
 
               // Remove from the map
               valueToGuideMap.synchronized {
-                valueToGuideMap(uuid) = SourceInfo("", SourceInfo.TxOverGoToDefault)
+                valueToGuideMap(id) = SourceInfo("", SourceInfo.TxOverGoToDefault)
               }
-              logInfo ("Broadcast " + uuid + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
+              logInfo ("Broadcast " + id + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
 
               // Send dummy ACK
               oos.writeObject(-1)
               oos.flush()
             } else if (messageType == FIND_BROADCAST_TRACKER) {
-              // Receive UUID
-              val uuid = ois.readObject.asInstanceOf[UUID]
+              // Receive Long
+              val id = ois.readObject.asInstanceOf[Long]
 
               var gInfo =
-                if (valueToGuideMap.contains(uuid)) valueToGuideMap(uuid)
+                if (valueToGuideMap.contains(id)) valueToGuideMap(id)
                 else SourceInfo("", SourceInfo.TxNotStartedRetry)
 
-              logDebug("Got new request: " + clientSocket + " for " + uuid + " : " + gInfo.listenPort)
+              logDebug("Got new request: " + clientSocket + " for " + id + " : " + gInfo.listenPort)
 
               // Send reply back
               oos.writeObject(gInfo)
@@ -224,7 +223,7 @@ extends Logging {
     }
   }
 
-  def getGuideInfo(variableUUID: UUID): SourceInfo = {
+  def getGuideInfo(variableLong: Long): SourceInfo = {
     var clientSocketToTracker: Socket = null
     var oosTracker: ObjectOutputStream = null
     var oisTracker: ObjectInputStream = null
@@ -247,8 +246,8 @@ extends Logging {
         oosTracker.writeObject(MultiTracker.FIND_BROADCAST_TRACKER)
         oosTracker.flush()
 
-        // Send UUID and receive GuideInfo
-        oosTracker.writeObject(variableUUID)
+        // Send Long and receive GuideInfo
+        oosTracker.writeObject(variableLong)
         oosTracker.flush()
         gInfo = oisTracker.readObject.asInstanceOf[SourceInfo]
       } catch {
@@ -276,7 +275,7 @@ extends Logging {
     return gInfo
   }
 
-  def registerBroadcast(uuid: UUID, gInfo: SourceInfo) {
+  def registerBroadcast(id: Long, gInfo: SourceInfo) {
     val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort)
     val oosST = new ObjectOutputStream(socket.getOutputStream)
     oosST.flush()
@@ -286,8 +285,8 @@ extends Logging {
     oosST.writeObject(REGISTER_BROADCAST_TRACKER)
     oosST.flush()
 
-    // Send UUID of this broadcast
-    oosST.writeObject(uuid)
+    // Send Long of this broadcast
+    oosST.writeObject(id)
     oosST.flush()
 
     // Send this tracker's information
@@ -303,7 +302,7 @@ extends Logging {
     socket.close()
   }
 
-  def unregisterBroadcast(uuid: UUID) {
+  def unregisterBroadcast(id: Long) {
     val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort)
     val oosST = new ObjectOutputStream(socket.getOutputStream)
     oosST.flush()
@@ -313,8 +312,8 @@ extends Logging {
     oosST.writeObject(UNREGISTER_BROADCAST_TRACKER)
     oosST.flush()
 
-    // Send UUID of this broadcast
-    oosST.writeObject(uuid)
+    // Send Long of this broadcast
+    oosST.writeObject(id)
     oosST.flush()
 
     // Receive ACK and throw it away
diff --git a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala
index 574477a5fc..c1148b22ca 100644
--- a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala
@@ -10,14 +10,15 @@ import scala.math
 import spark._
 import spark.storage.StorageLevel
 
-class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean)
-extends Broadcast[T] with Logging with Serializable {
+class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
+extends Broadcast[T](id) with Logging with Serializable {
 
   def value = value_
 
+  def blockId = "broadcast_" + id
+
   MultiTracker.synchronized {
-    SparkEnv.get.blockManager.putSingle(
-      uuid.toString, value_, StorageLevel.MEMORY_AND_DISK, false)
+    SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
   }
 
   @transient var arrayOfBlocks: Array[BroadcastBlock] = null
@@ -35,7 +36,7 @@ extends Broadcast[T] with Logging with Serializable {
   @transient var serveMR: ServeMultipleRequests = null
   @transient var guideMR: GuideMultipleRequests = null
 
-  @transient var hostAddress = Utils.localIpAddress
+  @transient var hostAddress = Utils.localIpAddress()
   @transient var listenPort = -1
   @transient var guidePort = -1
 
@@ -43,7 +44,7 @@ extends Broadcast[T] with Logging with Serializable {
 
   // Must call this after all the variables have been created/initialized
   if (!isLocal) {
-    sendBroadcast
+    sendBroadcast()
   }
 
   def sendBroadcast() {
@@ -84,20 +85,22 @@ extends Broadcast[T] with Logging with Serializable {
     listOfSources += masterSource
 
     // Register with the Tracker
-    MultiTracker.registerBroadcast(uuid,
+    MultiTracker.registerBroadcast(id,
       SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
   }
 
   private def readObject(in: ObjectInputStream) {
     in.defaultReadObject()
     MultiTracker.synchronized {
-      SparkEnv.get.blockManager.getSingle(uuid.toString) match {
-        case Some(x) => x.asInstanceOf[T]
-        case None => {
-          logInfo("Started reading broadcast variable " + uuid)
+      SparkEnv.get.blockManager.getSingle(blockId) match {
+        case Some(x) =>
+          value_ = x.asInstanceOf[T]
+
+        case None =>
+          logInfo("Started reading broadcast variable " + id)
           // Initializing everything because Master will only send null/0 values
           // Only the 1st worker in a node can be here. Others will get from cache
-          initializeWorkerVariables
+          initializeWorkerVariables()
 
           logInfo("Local host address: " + hostAddress)
 
@@ -108,18 +111,17 @@ extends Broadcast[T] with Logging with Serializable {
 
          val start = System.nanoTime
 
-          val receptionSucceeded = receiveBroadcast(uuid)
+          val receptionSucceeded = receiveBroadcast(id)
           if (receptionSucceeded) {
             value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
             SparkEnv.get.blockManager.putSingle(
-              uuid.toString, value_, StorageLevel.MEMORY_AND_DISK, false)
+              blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
           } else {
-            logError("Reading Broadcasted variable " + uuid + " failed")
+            logError("Reading broadcast variable " + id + " failed")
           }
 
           val time = (System.nanoTime - start) / 1e9
-          logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
-        }
+          logInfo("Reading broadcast variable " + id + " took " + time + " s")
       }
     }
   }
@@ -136,14 +138,14 @@ extends Broadcast[T] with Logging with Serializable {
 
     serveMR = null
 
-    hostAddress = Utils.localIpAddress
+    hostAddress = Utils.localIpAddress()
     listenPort = -1
 
    stopBroadcast = false
   }
 
-  def receiveBroadcast(variableUUID: UUID): Boolean = {
-    val gInfo = MultiTracker.getGuideInfo(variableUUID)
+  def receiveBroadcast(variableID: Long): Boolean = {
+    val gInfo = MultiTracker.getGuideInfo(variableID)
 
     if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
       return false
@@ -316,7 +318,7 @@ extends Broadcast[T] with Logging with Serializable {
         logInfo("Sending stopBroadcast notifications...")
         sendStopBroadcastNotifications
 
-        MultiTracker.unregisterBroadcast(uuid)
+        MultiTracker.unregisterBroadcast(id)
       } finally {
         if (serverSocket != null) {
           logInfo("GuideMultipleRequests now stopping...")
@@ -572,7 +574,10 @@ extends Broadcast[T] with Logging with Serializable {
 
 class TreeBroadcastFactory
 extends BroadcastFactory {
-  def initialize(isMaster: Boolean) = MultiTracker.initialize(isMaster)
-  def newBroadcast[T](value_ : T, isLocal: Boolean) = new TreeBroadcast[T](value_, isLocal)
-  def stop() = MultiTracker.stop
+  def initialize(isMaster: Boolean) { MultiTracker.initialize(isMaster) }
+
+  def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
+    new TreeBroadcast[T](value_, isLocal, id)
+
+  def stop() { MultiTracker.stop() }
 }
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index e7dea904c3..8be2d08cfc 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -82,8 +82,6 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m
 
   val compress = System.getProperty("spark.blockManager.compress", "false").toBoolean
 
-  initLogging()
-
   initialize()
 
   /**
diff --git a/core/src/main/scala/spark/storage/MemoryStore.scala b/core/src/main/scala/spark/storage/MemoryStore.scala
index 93520dc4ff..ea6f3c4fcc 100644
--- a/core/src/main/scala/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/spark/storage/MemoryStore.scala
@@ -147,6 +147,8 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
     logInfo("MemoryStore cleared")
   }
 
+  // TODO: This should be able to return false if the space is larger than our total memory,
+  // or if adding this block would require evicting another one from the same RDD
   private def ensureFreeSpace(space: Long) {
     logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format(
       space, currentMemory, maxMemory))
--
cgit v1.2.3
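
A note on the API change above: broadcast values are no longer named by java.util.UUIDs but by Long ids that BroadcastManager hands out from a process-local AtomicLong counter, and each implementation stores its value in the BlockManager under the key "broadcast_" + id. The sketch below shows how those pieces fit together. It is illustrative only: LocalBroadcastFactory and BroadcastIdDemo are hypothetical stand-ins, not classes from this patch; only the Broadcast and BroadcastFactory signatures mirror the code above.

import java.util.concurrent.atomic.AtomicLong

// Simplified stand-ins for spark.broadcast.Broadcast and BroadcastFactory,
// mirroring the signatures introduced by this patch.
abstract class Broadcast[T](val id: Long) extends Serializable {
  def value: T
  // Blocks are keyed by a deterministic string derived from the id.
  def blockId: String = "broadcast_" + id
  override def toString = "spark.Broadcast(" + id + ")"
}

trait BroadcastFactory {
  def initialize(isMaster: Boolean): Unit
  def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long): Broadcast[T]
  def stop(): Unit
}

// A toy in-memory factory (hypothetical; not one of the implementations above).
class LocalBroadcastFactory extends BroadcastFactory {
  def initialize(isMaster: Boolean) {}
  def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] =
    new Broadcast[T](id) { def value = value_ }
  def stop() {}
}

object BroadcastIdDemo {
  // As in BroadcastManager: ids come from a process-local counter, so they
  // are unique per driver rather than globally unique like the old UUIDs.
  private val nextBroadcastId = new AtomicLong(0)
  private val factory: BroadcastFactory = new LocalBroadcastFactory

  def newBroadcast[T](value_ : T, isLocal: Boolean): Broadcast[T] =
    factory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())

  def main(args: Array[String]) {
    val b0 = newBroadcast(Seq(1, 2, 3), isLocal = true)
    val b1 = newBroadcast("hello", isLocal = true)
    println(b0.blockId + " -> " + b0.value) // broadcast_0 -> List(1, 2, 3)
    println(b1.blockId + " -> " + b1.value) // broadcast_1 -> hello
  }
}

Since ids are minted by a single counter on the driver and only need to be unique within one application, the sequential Long is sufficient, and the short "broadcast_N" keys make the BlockManager entries and the reworked log messages easier to correlate than the old UUID strings.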