diff options
4 files changed, 51 insertions, 32 deletions
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index f2d9499bad..4cdb9710ec 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -509,10 +509,15 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m * Replicate block to another node. */ + var firstTime = true + var peers : Seq[BlockManagerId] = null private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) { val tLevel: StorageLevel = new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) - var peers = master.mustGetPeers(GetPeers(blockManagerId, level.replication - 1)) + if (firstTime) { + peers = master.mustGetPeers(GetPeers(blockManagerId, level.replication - 1)) + firstTime = false; + } for (peer: BlockManagerId <- peers) { val start = System.nanoTime data.rewind() diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala index acf97c1883..9f9001e4d5 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -4,6 +4,7 @@ import spark.Logging import spark.SparkEnv import scala.collection.mutable.HashMap +import scala.collection.mutable.Queue import akka.actor._ import akka.pattern.ask @@ -28,6 +29,17 @@ extends Logging { logInfo("Registered receiver for network stream " + streamId) sender ! true } + case GotBlockIds(streamId, blockIds) => { + val tmp = receivedBlockIds.synchronized { + if (!receivedBlockIds.contains(streamId)) { + receivedBlockIds += ((streamId, new Queue[String])) + } + receivedBlockIds(streamId) + } + tmp.synchronized { + tmp ++= blockIds + } + } } } @@ -69,8 +81,8 @@ extends Logging { val networkInputStreamIds = networkInputStreams.map(_.id).toArray val receiverExecutor = new ReceiverExecutor() val receiverInfo = new HashMap[Int, ActorRef] - val receivedBlockIds = new HashMap[Int, Array[String]] - val timeout = 1000.milliseconds + val receivedBlockIds = new HashMap[Int, Queue[String]] + val timeout = 5000.milliseconds var currentTime: Time = null @@ -86,22 +98,12 @@ extends Logging { } def getBlockIds(receiverId: Int, time: Time): Array[String] = synchronized { - if (currentTime == null || time > currentTime) { - logInfo("Getting block ids from receivers for " + time) - implicit val ec = ssc.env.actorSystem.dispatcher - receivedBlockIds.clear() - val message = new GetBlockIds(time) - val listOfFutures = receiverInfo.values.map( - _.ask(message)(timeout).mapTo[GotBlockIds] - ).toList - val futureOfList = Future.sequence(listOfFutures) - val allBlockIds = Await.result(futureOfList, timeout) - receivedBlockIds ++= allBlockIds.map(x => (x.streamId, x.blocksIds)) - if (receivedBlockIds.size != receiverInfo.size) { - throw new Exception("Unexpected number of the Block IDs received") - } - currentTime = time + val queue = receivedBlockIds.synchronized { + receivedBlockIds.getOrElse(receiverId, new Queue[String]()) + } + val result = queue.synchronized { + queue.dequeueAll(x => true) } - receivedBlockIds.getOrElse(receiverId, Array[String]()) + result.toArray } -}
\ No newline at end of file +} diff --git a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala index d59c245a23..d29aea7886 100644 --- a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala @@ -86,14 +86,15 @@ class RawInputDStream[T: ClassManifest]( private class ReceiverActor(env: SparkEnv, receivingThread: Thread) extends Actor { val newBlocks = new ArrayBuffer[String] + logInfo("Attempting to register with tracker") + val ip = System.getProperty("spark.master.host", "localhost") + val port = System.getProperty("spark.master.port", "7077").toInt + val actorName: String = "NetworkInputTracker" + val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) + val trackerActor = env.actorSystem.actorFor(url) + val timeout = 5.seconds + override def preStart() { - logInfo("Attempting to register with tracker") - val ip = System.getProperty("spark.master.host", "localhost") - val port = System.getProperty("spark.master.port", "7077").toInt - val actorName: String = "NetworkInputTracker" - val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) - val trackerActor = env.actorSystem.actorFor(url) - val timeout = 1.seconds val future = trackerActor.ask(RegisterReceiver(streamId, self))(timeout) Await.result(future, timeout) } @@ -101,6 +102,7 @@ class RawInputDStream[T: ClassManifest]( override def receive = { case BlockPublished(blockId) => newBlocks += blockId + val future = trackerActor ! GotBlockIds(streamId, Array(blockId)) case GetBlockIds(time) => logInfo("Got request for block IDs for " + time) @@ -111,5 +113,6 @@ class RawInputDStream[T: ClassManifest]( receivingThread.interrupt() sender ! true } + } } diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala index 298d9ef381..9702003805 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala @@ -1,11 +1,24 @@ package spark.streaming.examples import spark.util.IntParam +import spark.SparkContext +import spark.SparkContext._ import spark.storage.StorageLevel import spark.streaming._ import spark.streaming.StreamingContext._ +import WordCount2_ExtraFunctions._ + object WordCountRaw { + def moreWarmup(sc: SparkContext) { + (0 until 40).foreach {i => + sc.parallelize(1 to 20000000, 1000) + .map(_ % 1331).map(_.toString) + .mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) + .collect() + } + } + def main(args: Array[String]) { if (args.length != 7) { System.err.println("Usage: WordCountRaw <master> <streams> <host> <port> <batchMs> <chkptMs> <reduces>") @@ -20,16 +33,12 @@ object WordCountRaw { ssc.setBatchDuration(Milliseconds(batchMs)) // Make sure some tasks have started on each node - ssc.sc.parallelize(1 to 1000, 1000).count() - ssc.sc.parallelize(1 to 1000, 1000).count() - ssc.sc.parallelize(1 to 1000, 1000).count() + moreWarmup(ssc.sc) val rawStreams = (1 to streams).map(_ => ssc.createRawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_2)).toArray val union = new UnifiedDStream(rawStreams) - import WordCount2_ExtraFunctions._ - val windowedCounts = union.mapPartitions(splitAndCountPartitions) .reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces) windowedCounts.persist(StorageLevel.MEMORY_ONLY_DESER, StorageLevel.MEMORY_ONLY_DESER_2, |